kurbo/
common.rs

1// Copyright 2018 the Kurbo Authors
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Common mathematical operations
5
6#![allow(missing_docs)]
7
#[cfg(not(feature = "std"))]
mod sealed {
    /// A [sealed trait](https://predr.ag/blog/definitive-guide-to-sealed-traits-in-rust/)
    /// which stops [`super::FloatFuncs`] from being implemented outside kurbo. This could
    /// be relaxed in the future if there are good reasons to allow external impls.
    /// The benefit from being sealed is that we can add methods without breaking downstream
    /// implementations.
    pub trait FloatFuncsSealed {}
}
17
18use arrayvec::ArrayVec;
19
/// Defines a trait that chooses between libstd or libm implementations of float methods.
///
/// Each entry has the form `fn name(self, args...) -> Ret => f64_fn/f32_fn;`, where
/// `f64_fn` and `f32_fn` name the `libm` functions backing the `f64` and `f32` impls.
/// Every generated item is gated on `not(feature = "std")`, so nothing is emitted when
/// `std` is enabled. Without `std`, the `libm` feature is required; otherwise each
/// generated method expands to a `compile_error!`.
macro_rules! define_float_funcs {
    ($(
        fn $name:ident(self $(,$arg:ident: $arg_ty:ty)*) -> $ret:ty
        => $lname:ident/$lfname:ident;
    )+) => {

        /// Since core doesn't depend upon libm, this provides libm implementations
        /// of float functions which are typically provided by the std library, when
        /// the `std` feature is not enabled.
        ///
        /// For documentation see the respective functions in the std library.
        #[cfg(not(feature = "std"))]
        pub trait FloatFuncs : Sized + sealed::FloatFuncsSealed {
            /// Special implementation for signum, because libm doesn't have it.
            fn signum(self) -> Self;

            $(fn $name(self $(,$arg: $arg_ty)*) -> $ret;)+
        }

        #[cfg(not(feature = "std"))]
        impl sealed::FloatFuncsSealed for f32 {}

        #[cfg(not(feature = "std"))]
        impl FloatFuncs for f32 {
            #[inline]
            fn signum(self) -> f32 {
                if self.is_nan() {
                    f32::NAN
                } else {
                    1.0_f32.copysign(self)
                }
            }

            // `#[inline]` (as on `signum` above) lets these thin libm wrappers
            // be inlined into downstream crates.
            $(#[inline]
            fn $name(self $(,$arg: $arg_ty)*) -> $ret {
                // Extra arguments are converted with `as _` to whatever type the
                // libm function expects (e.g. `powi`'s `i32` exponent).
                #[cfg(feature = "libm")]
                return libm::$lfname(self $(,$arg as _)*);

                #[cfg(not(feature = "libm"))]
                compile_error!("kurbo requires either the `std` or `libm` feature")
            })+
        }

        #[cfg(not(feature = "std"))]
        impl sealed::FloatFuncsSealed for f64 {}
        #[cfg(not(feature = "std"))]
        impl FloatFuncs for f64 {
            #[inline]
            fn signum(self) -> f64 {
                if self.is_nan() {
                    f64::NAN
                } else {
                    1.0_f64.copysign(self)
                }
            }

            // See the f32 impl: thin wrappers, marked `#[inline]`.
            $(#[inline]
            fn $name(self $(,$arg: $arg_ty)*) -> $ret {
                #[cfg(feature = "libm")]
                return libm::$lname(self $(,$arg as _)*);

                #[cfg(not(feature = "libm"))]
                compile_error!("kurbo requires either the `std` or `libm` feature")
            })+
        }
    }
}
86
// Instantiate `FloatFuncs` (no_std only). Each line maps a std float method to
// its libm counterparts, written as `f64 function / f32 function`.
define_float_funcs! {
    fn abs(self) -> Self => fabs/fabsf;
    fn acos(self) -> Self => acos/acosf;
    fn atan2(self, other: Self) -> Self => atan2/atan2f;
    fn cbrt(self) -> Self => cbrt/cbrtf;
    fn ceil(self) -> Self => ceil/ceilf;
    fn cos(self) -> Self => cos/cosf;
    fn copysign(self, sign: Self) -> Self => copysign/copysignf;
    fn floor(self) -> Self => floor/floorf;
    fn hypot(self, other: Self) -> Self => hypot/hypotf;
    // std's natural log is `ln`; libm calls it `log`.
    fn ln(self) -> Self => log/logf;
    fn log2(self) -> Self => log2/log2f;
    // std's fused multiply-add is `mul_add`; libm calls it `fma`.
    fn mul_add(self, a: Self, b: Self) -> Self => fma/fmaf;
    // libm has no integer-power function, so `powi` is routed through `pow`;
    // the macro converts the `i32` exponent with `as _`.
    fn powi(self, n: i32) -> Self => pow/powf;
    fn powf(self, n: Self) -> Self => pow/powf;
    fn round(self) -> Self => round/roundf;
    fn sin(self) -> Self => sin/sinf;
    // Returns `(sin, cos)`, like std's `sin_cos`.
    fn sin_cos(self) -> (Self, Self) => sincos/sincosf;
    fn sqrt(self) -> Self => sqrt/sqrtf;
    fn tan(self) -> Self => tan/tanf;
    fn trunc(self) -> Self => trunc/truncf;
}
109
/// Extension methods for `f32` and `f64`.
pub trait FloatExt<T> {
    /// Rounds away from zero to the nearest integer; values that are already
    /// integral are returned unchanged.
    ///
    /// This is the mirror image of `trunc` (which rounds toward zero), in the
    /// same way that `ceil` mirrors `floor`.
    ///
    /// # Examples
    ///
    /// ```
    /// use kurbo::common::FloatExt;
    ///
    /// let f = 3.7_f64;
    /// let g = 3.0_f64;
    /// let h = -3.7_f64;
    /// let i = -5.1_f32;
    ///
    /// assert_eq!(f.expand(), 4.0);
    /// assert_eq!(g.expand(), 3.0);
    /// assert_eq!(h.expand(), -4.0);
    /// assert_eq!(i.expand(), -6.0);
    /// ```
    fn expand(&self) -> T;
}

impl FloatExt<f64> for f64 {
    #[inline]
    fn expand(&self) -> f64 {
        // Round the magnitude up, then put the original sign back.
        let rounded_magnitude = self.abs().ceil();
        rounded_magnitude.copysign(*self)
    }
}

impl FloatExt<f32> for f32 {
    #[inline]
    fn expand(&self) -> f32 {
        // Round the magnitude up, then put the original sign back.
        let rounded_magnitude = self.abs().ceil();
        rounded_magnitude.copysign(*self)
    }
}
148
/// Find real roots of cubic equation.
///
/// The implementation is not (yet) fully robust, but it does handle the case
/// where `c3` is zero (in that case, solving the quadratic equation).
///
/// See: <https://momentsingraphics.de/CubicRoots.html>
///
/// That implementation is in turn based on Jim Blinn's "How to Solve a Cubic
/// Equation", which is masterful.
///
/// Return values of x for which c0 + c1 x + c2 x² + c3 x³ = 0.
///
/// The roots are returned in no particular order. Depending on the sign of the
/// discriminant, one value (discriminant < 0), two values (exactly zero; they
/// coincide for a triple root) or three values (discriminant > 0) are returned.
/// Overflow of intermediate results is not handled yet (see the TODO in the
/// body), so extreme coefficients may produce non-finite values.
pub fn solve_cubic(c0: f64, c1: f64, c2: f64, c3: f64) -> ArrayVec<f64, 3> {
    let mut result = ArrayVec::new();
    let c3_recip = c3.recip();
    const ONETHIRD: f64 = 1. / 3.;
    // Make the cubic monic, in Blinn's form x³ + 3·c2 x² + 3·c1 x + c0: the
    // 1/3 factors are folded into the scaled c2 and c1.
    let scaled_c2 = c2 * (ONETHIRD * c3_recip);
    let scaled_c1 = c1 * (ONETHIRD * c3_recip);
    let scaled_c0 = c0 * c3_recip;
    if !(scaled_c0.is_finite() && scaled_c1.is_finite() && scaled_c2.is_finite()) {
        // cubic coefficient is zero or nearly so.
        return solve_quadratic(c0, c1, c2).iter().copied().collect();
    }
    // From here on c0, c1, c2 refer to the scaled coefficients.
    let (c0, c1, c2) = (scaled_c0, scaled_c1, scaled_c2);
    // (d0, d1, d2) is called "Delta" in article
    let d0 = (-c2).mul_add(c2, c1);
    let d1 = (-c1).mul_add(c2, c0);
    let d2 = c2 * c0 - c1 * c1;
    // d is called "Discriminant"
    let d = 4.0 * d0 * d2 - d1 * d1;
    // de is called "Depressed.x", Depressed.y = d0
    let de = (-2.0 * c2).mul_add(d0, d1);
    // TODO: handle the cases where these intermediate results overflow.
    if d < 0.0 {
        // One real root, from the sum of two real cube roots (Cardano-style).
        let sq = (-0.25 * d).sqrt();
        let r = -0.5 * de;
        let t1 = (r + sq).cbrt() + (r - sq).cbrt();
        result.push(t1 - c2);
    } else if d == 0.0 {
        // Repeated root: push the simple root and the double root.
        let t1 = (-d0).sqrt().copysign(de);
        result.push(t1 - c2);
        result.push(-2.0 * t1 - c2);
    } else {
        // Three real roots, via the trigonometric method.
        let th = d.sqrt().atan2(-de) * ONETHIRD;
        // (th_cos, th_sin) is called "CubicRoot"
        let (th_sin, th_cos) = th.sin_cos();
        // (r0, r1, r2) is called "Root"
        let r0 = th_cos;
        let ss3 = th_sin * 3.0f64.sqrt();
        let r1 = 0.5 * (-th_cos + ss3);
        let r2 = 0.5 * (-th_cos - ss3);
        let t = 2.0 * (-d0).sqrt();
        // Undo the depression shift by subtracting c2.
        result.push(t.mul_add(r0, -c2));
        result.push(t.mul_add(r1, -c2));
        result.push(t.mul_add(r2, -c2));
    }
    result
}
206
207/// Find real roots of quadratic equation.
208///
209/// Return values of x for which c0 + c1 x + c2 x² = 0.
210///
211/// This function tries to be quite numerically robust. If the equation
212/// is nearly linear, it will return the root ignoring the quadratic term;
213/// the other root might be out of representable range. In the degenerate
214/// case where all coefficients are zero, so that all values of x satisfy
215/// the equation, a single `0.0` is returned.
216pub fn solve_quadratic(c0: f64, c1: f64, c2: f64) -> ArrayVec<f64, 2> {
217    let mut result = ArrayVec::new();
218    let sc0 = c0 * c2.recip();
219    let sc1 = c1 * c2.recip();
220    if !sc0.is_finite() || !sc1.is_finite() {
221        // c2 is zero or very small, treat as linear eqn
222        let root = -c0 / c1;
223        if root.is_finite() {
224            result.push(root);
225        } else if c0 == 0.0 && c1 == 0.0 {
226            // Degenerate case
227            result.push(0.0);
228        }
229        return result;
230    }
231    let arg = sc1 * sc1 - 4. * sc0;
232    let root1 = if !arg.is_finite() {
233        // Likely, calculation of sc1 * sc1 overflowed. Find one root
234        // using sc1 x + x² = 0, other root as sc0 / root1.
235        -sc1
236    } else {
237        if arg < 0.0 {
238            return result;
239        } else if arg == 0.0 {
240            result.push(-0.5 * sc1);
241            return result;
242        }
243        // See https://math.stackexchange.com/questions/866331
244        -0.5 * (sc1 + arg.sqrt().copysign(sc1))
245    };
246    let root2 = sc0 / root1;
247    if root2.is_finite() {
248        // Sort just to be friendly and make results deterministic.
249        if root2 > root1 {
250            result.push(root1);
251            result.push(root2);
252        } else {
253            result.push(root2);
254            result.push(root1);
255        }
256    } else {
257        result.push(root1);
258    }
259    result
260}
261
/// Relative error of `raw` with respect to the reference coefficient `a`.
///
/// Falls back to the absolute value of `raw` when `a` is zero, so the
/// measure stays finite. A helper from the Orellana and De Michele paper.
fn eps_rel(raw: f64, a: f64) -> f64 {
    let err = if a != 0.0 { (raw - a) / a } else { raw };
    err.abs()
}
272
273/// Find real roots of a quartic equation.
274///
275/// This is a fairly literal implementation of the method described in:
276/// Algorithm 1010: Boosting Efficiency in Solving Quartic Equations with
277/// No Compromise in Accuracy, Orellana and De Michele, ACM
278/// Transactions on Mathematical Software, Vol. 46, No. 2, May 2020.
279pub fn solve_quartic(c0: f64, c1: f64, c2: f64, c3: f64, c4: f64) -> ArrayVec<f64, 4> {
280    if c4 == 0.0 {
281        return solve_cubic(c0, c1, c2, c3).iter().copied().collect();
282    }
283    if c0 == 0.0 {
284        // Note: appends 0 root at end, doesn't sort. We might want to do that.
285        return solve_cubic(c1, c2, c3, c4)
286            .iter()
287            .copied()
288            .chain(Some(0.0))
289            .collect();
290    }
291    let a = c3 / c4;
292    let b = c2 / c4;
293    let c = c1 / c4;
294    let d = c0 / c4;
295    if let Some(result) = solve_quartic_inner(a, b, c, d, false) {
296        return result;
297    }
298    // Do polynomial rescaling
299    const K_Q: f64 = 7.16e76;
300    for rescale in [false, true] {
301        if let Some(result) = solve_quartic_inner(
302            a / K_Q,
303            b / K_Q.powi(2),
304            c / K_Q.powi(3),
305            d / K_Q.powi(4),
306            rescale,
307        ) {
308            return result.iter().map(|x| x * K_Q).collect();
309        }
310    }
311    // Overflow happened, just return no roots.
312    //println!("overflow, no roots returned");
313    ArrayVec::default()
314}
315
316fn solve_quartic_inner(a: f64, b: f64, c: f64, d: f64, rescale: bool) -> Option<ArrayVec<f64, 4>> {
317    factor_quartic_inner(a, b, c, d, rescale).map(|quadratics| {
318        quadratics
319            .iter()
320            .flat_map(|(a, b)| solve_quadratic(*b, *a, 1.0))
321            .collect()
322    })
323}
324
/// Factor a quartic into two quadratics.
///
/// Attempt to factor a quartic equation into two quadratic equations. Returns `None` either if there
/// is overflow (in which case rescaling might succeed) or the factorization would result in
/// complex coefficients.
///
/// The quartic is the monic x⁴ + a x³ + b x² + c x + d. On success the two returned pairs
/// `(alpha, beta)` each describe a factor x² + alpha x + beta. When `rescale` is true, the
/// coefficients of the intermediate depressed cubic are computed divided by an internal
/// constant (`K_C`) so they stay finite, and the cubic's root is scaled back afterwards.
///
/// Discussion question: distinguish the two cases in return value?
pub fn factor_quartic_inner(
    a: f64,
    b: f64,
    c: f64,
    d: f64,
    rescale: bool,
) -> Option<ArrayVec<(f64, f64), 2>> {
    // Sum of relative errors of the x³, x² and x coefficients of
    // (x² + a1 x + b1)(x² + a2 x + b2) compared with a, b and c.
    let calc_eps_q = |a1, b1, a2, b2| {
        let eps_a = eps_rel(a1 + a2, a);
        let eps_b = eps_rel(b1 + a1 * a2 + b2, b);
        let eps_c = eps_rel(b1 * a2 + a1 * b2, c);
        eps_a + eps_b + eps_c
    };
    // As above, also including the constant term d.
    let calc_eps_t = |a1, b1, a2, b2| calc_eps_q(a1, b1, a2, b2) + eps_rel(b1 * b2, d);
    // Choose a shift s; a', b', c', d' below are the coefficients of the
    // shifted quartic P(y + s) (d' is P(s) in Horner form).
    let disc = 9. * a * a - 24. * b;
    let s = if disc >= 0.0 {
        -2. * b / (3. * a + disc.sqrt().copysign(a))
    } else {
        -0.25 * a
    };
    let a_prime = a + 4. * s;
    let b_prime = b + 3. * s * (a + 2. * s);
    let c_prime = c + s * (2. * b + s * (3. * a + 4. * s));
    let d_prime = d + s * (c + s * (b + s * (a + s)));
    // g' and h' are the coefficients of the depressed cubic
    // phi³ + g' phi + h' = 0, whose dominant root is used below.
    let g_prime;
    let h_prime;
    const K_C: f64 = 3.49e102;
    if rescale {
        // Same formulas as the `else` branch, evaluated on coefficients divided
        // by K_C, so that g' and h' come out divided by K_C² and K_C³.
        let a_prime_s = a_prime / K_C;
        let b_prime_s = b_prime / K_C;
        let c_prime_s = c_prime / K_C;
        let d_prime_s = d_prime / K_C;
        g_prime = a_prime_s * c_prime_s - (4. / K_C) * d_prime_s - (1. / 3.) * b_prime_s.powi(2);
        h_prime = (a_prime_s * c_prime_s + (8. / K_C) * d_prime_s - (2. / 9.) * b_prime_s.powi(2))
            * (1. / 3.)
            * b_prime_s
            - c_prime_s * (c_prime_s / K_C)
            - a_prime_s.powi(2) * d_prime_s;
    } else {
        g_prime = a_prime * c_prime - 4. * d_prime - (1. / 3.) * b_prime.powi(2);
        h_prime =
            (a_prime * c_prime + 8. * d_prime - (2. / 9.) * b_prime.powi(2)) * (1. / 3.) * b_prime
                - c_prime.powi(2)
                - a_prime.powi(2) * d_prime;
    }
    if !(g_prime.is_finite() && h_prime.is_finite()) {
        return None;
    }
    let phi = depressed_cubic_dominant(g_prime, h_prime);
    // The rescaled cubic's root is phi / K_C; undo that.
    let phi = if rescale { phi * K_C } else { phi };
    // Model the quartic as (x² + l_1 x + l_3)² + d_2 (x + l_2)². l_1 and l_3
    // follow directly; three candidate (d_2, l_2) pairs are formed (the third
    // combines d_2 from the first with l_2 from the second) and the one that
    // best reproduces b, c and d is kept.
    let l_1 = a * 0.5;
    let l_3 = (1. / 6.) * b + 0.5 * phi;
    let delt_2 = c - a * l_3;
    let d_2_cand_1 = (2. / 3.) * b - phi - l_1 * l_1;
    let l_2_cand_1 = 0.5 * delt_2 / d_2_cand_1;
    let l_2_cand_2 = 2. * (d - l_3 * l_3) / delt_2;
    let d_2_cand_2 = 0.5 * delt_2 / l_2_cand_2;
    let d_2_cand_3 = d_2_cand_1;
    let l_2_cand_3 = l_2_cand_2;
    let mut d_2_best = 0.0;
    let mut l_2_best = 0.0;
    let mut eps_l_best = 0.0;
    for (i, (d_2, l_2)) in [
        (d_2_cand_1, l_2_cand_1),
        (d_2_cand_2, l_2_cand_2),
        (d_2_cand_3, l_2_cand_3),
    ]
    .iter()
    .enumerate()
    {
        let eps_0 = eps_rel(d_2 + l_1 * l_1 + 2. * l_3, b);
        let eps_1 = eps_rel(2. * (d_2 * l_2 + l_1 * l_3), c);
        let eps_2 = eps_rel(d_2 * l_2 * l_2 + l_3 * l_3, d);
        let eps_l = eps_0 + eps_1 + eps_2;
        if i == 0 || eps_l < eps_l_best {
            d_2_best = *d_2;
            l_2_best = *l_2;
            eps_l_best = eps_l;
        }
    }
    let d_2 = d_2_best;
    let l_2 = l_2_best;
    let mut alpha_1;
    let mut beta_1;
    let mut alpha_2;
    let mut beta_2;
    if d_2 < 0.0 {
        // Negative d_2: a difference of two squares, so the quartic splits
        // into two real quadratic factors.
        let sq = (-d_2).sqrt();
        alpha_1 = l_1 + sq;
        beta_1 = l_3 + sq * l_2;
        alpha_2 = l_1 - sq;
        beta_2 = l_3 - sq * l_2;
        // Recompute the smaller-magnitude beta from beta_1 * beta_2 = d
        // instead of the cancellation-prone difference above.
        if beta_2.abs() < beta_1.abs() {
            beta_2 = d / beta_1;
        } else if beta_2.abs() > beta_1.abs() {
            beta_1 = d / beta_2;
        }
        let cands;
        if alpha_1.abs() != alpha_2.abs() {
            // Likewise re-derive the smaller-magnitude alpha from the a, b or
            // c constraint, keeping whichever finite candidate fits best.
            if alpha_1.abs() < alpha_2.abs() {
                let a1_cand_1 = (c - beta_1 * alpha_2) / beta_2;
                let a1_cand_2 = (b - beta_2 - beta_1) / alpha_2;
                let a1_cand_3 = a - alpha_2;
                // Note: cand 3 is first because it is infallible, simplifying logic
                cands = [
                    (a1_cand_3, alpha_2),
                    (a1_cand_1, alpha_2),
                    (a1_cand_2, alpha_2),
                ];
            } else {
                let a2_cand_1 = (c - alpha_1 * beta_2) / beta_1;
                let a2_cand_2 = (b - beta_2 - beta_1) / alpha_1;
                let a2_cand_3 = a - alpha_1;
                cands = [
                    (alpha_1, a2_cand_3),
                    (alpha_1, a2_cand_1),
                    (alpha_1, a2_cand_2),
                ];
            }
            let mut eps_q_best = 0.0;
            for (i, (a1, a2)) in cands.iter().enumerate() {
                if a1.is_finite() && a2.is_finite() {
                    let eps_q = calc_eps_q(*a1, beta_1, *a2, beta_2);
                    if i == 0 || eps_q < eps_q_best {
                        alpha_1 = *a1;
                        alpha_2 = *a2;
                        eps_q_best = eps_q;
                    }
                }
            }
        }
    } else if d_2 == 0.0 {
        // d_2 == 0: the quartic is approximately (x² + l_1 x + l_3)² + d_3,
        // factored as a difference of squares using sqrt(-d_3).
        // NOTE(review): if d_3 > 0 (possibly only rounding noise), the square
        // roots below are NaN, and the function returns `Some` with NaN betas
        // rather than `None`. Confirm whether this should return `None` or
        // clamp d_3 to zero.
        let d_3 = d - l_3 * l_3;
        alpha_1 = l_1;
        beta_1 = l_3 + (-d_3).sqrt();
        alpha_2 = l_1;
        beta_2 = l_3 - (-d_3).sqrt();
        if beta_1.abs() > beta_2.abs() {
            beta_2 = d / beta_1;
        } else if beta_2.abs() > beta_1.abs() {
            beta_1 = d / beta_2;
        }
        // TODO: handle case d_2 is very small?
    } else {
        // This case means no real roots; in the most general case we might want
        // to factor into quadratic equations with complex coefficients.
        return None;
    }
    // Newton-Raphson iteration on alpha/beta coeff's. A step is kept only if
    // it lowers the total relative error eps_t; at most 8 steps are taken.
    let mut eps_t = calc_eps_t(alpha_1, beta_1, alpha_2, beta_2);
    for _ in 0..8 {
        if eps_t == 0.0 {
            break;
        }
        // Residuals of the four coefficient equations.
        let f_0 = beta_1 * beta_2 - d;
        let f_1 = beta_1 * alpha_2 + alpha_1 * beta_2 - c;
        let f_2 = beta_1 + alpha_1 * alpha_2 + beta_2 - b;
        let f_3 = alpha_1 + alpha_2 - a;
        let c_1 = alpha_1 - alpha_2;
        // Determinant of the Jacobian; a singular Jacobian stops the iteration.
        let det_j = beta_1 * beta_1 - beta_1 * (alpha_2 * c_1 + 2. * beta_2)
            + beta_2 * (alpha_1 * c_1 + beta_2);
        if det_j == 0.0 {
            break;
        }
        let inv = det_j.recip();
        let c_2 = beta_2 - beta_1;
        let c_3 = beta_1 * alpha_2 - alpha_1 * beta_2;
        let dz_0 = c_1 * f_0 + c_2 * f_1 + c_3 * f_2 - (beta_1 * c_2 + alpha_1 * c_3) * f_3;
        let dz_1 = (alpha_1 * c_1 + c_2) * f_0
            - beta_1 * c_1 * f_1
            - beta_1 * c_2 * f_2
            - beta_1 * c_3 * f_3;
        let dz_2 = -c_1 * f_0 - c_2 * f_1 - c_3 * f_2 + (alpha_2 * c_3 + beta_2 * c_2) * f_3;
        let dz_3 = -(alpha_2 * c_1 + c_2) * f_0
            + beta_2 * c_1 * f_1
            + beta_2 * c_2 * f_2
            + beta_2 * c_3 * f_3;
        let a1 = alpha_1 - inv * dz_0;
        let b1 = beta_1 - inv * dz_1;
        let a2 = alpha_2 - inv * dz_2;
        let b2 = beta_2 - inv * dz_3;
        let new_eps_t = calc_eps_t(a1, b1, a2, b2);
        // We break if the new eps is equal, paper keeps going
        if new_eps_t < eps_t {
            alpha_1 = a1;
            beta_1 = b1;
            alpha_2 = a2;
            beta_2 = b2;
            eps_t = new_eps_t;
        } else {
            break;
        }
    }
    Some([(alpha_1, beta_1), (alpha_2, beta_2)].into())
}
531
/// Dominant root of depressed cubic x^3 + gx + h = 0.
///
/// Section 2.2 of Orellana and De Michele. When the cubic has three real
/// roots this is the largest-magnitude one, from the trigonometric formula;
/// otherwise it is the single real root, from Cardano's formula. Very large
/// coefficients are handled with a rescaled discriminant to avoid overflow.
/// The estimate is then polished with up to eight Newton-Raphson steps, each
/// kept only if it reduces the residual.
// Note: some of the techniques in here might be useful to improve the
// cubic solver, and vice versa.
fn depressed_cubic_dominant(g: f64, h: f64) -> f64 {
    let q = (-1. / 3.) * g;
    let r = 0.5 * h;
    // For large |q| or |r|, q³ and r² would overflow. In that case `k` is a
    // rescaled form of the discriminant r² - q³, computed from ratios.
    let k = if q.abs() < 1e102 && r.abs() < 1e154 {
        None
    } else if q.abs() < r.abs() {
        Some(1. - q * (q / r).powi(2))
    } else {
        Some(q.signum() * ((r / q).powi(2) / q - 1.0))
    };
    // A negative discriminant (r² < q³) means three real roots.
    let three_real_roots = match k {
        Some(k) => k < 0.0,
        None => r * r < q.powi(3),
    };
    let phi_0 = if k.is_some() && r == 0.0 {
        // Rescaled path with h == 0: the cubic is x (x² + g).
        if g > 0.0 {
            0.0
        } else {
            (-g).sqrt()
        }
    } else if three_real_roots {
        // Trigonometric solution, selecting the largest-magnitude root.
        let t = if k.is_some() {
            r / q / q.sqrt()
        } else {
            r / q.powi(3).sqrt()
        };
        -2. * q.sqrt() * (t.abs().acos() * (1. / 3.)).cos().copysign(t)
    } else {
        // Single real root via Cardano's formula.
        let a = match k {
            Some(k) if q.abs() < r.abs() => -r * (1. + k.sqrt()),
            Some(k) => -r - (q.abs().sqrt() * q * k.sqrt()).copysign(r),
            None => -r - (r * r - q.powi(3)).sqrt().copysign(r),
        }
        .cbrt();
        let b = if a == 0.0 { 0.0 } else { q / a };
        a + b
    };
    // Refine with Newton-Raphson iteration.
    let mut x = phi_0;
    let mut f = (x * x + g) * x + h;
    // Already converged if the residual is at rounding level relative to the
    // largest-magnitude term of the polynomial. Magnitudes are required:
    // negative terms must not shrink (or negate) the bound.
    let term_scale = x.powi(3).abs().max((g * x).abs()).max(h.abs());
    if f.abs() < f64::EPSILON * term_scale {
        return x;
    }
    for _ in 0..8 {
        let delt_f = 3. * x * x + g;
        if delt_f == 0.0 {
            break;
        }
        let new_x = x - f / delt_f;
        let new_f = (new_x * new_x + g) * new_x + h;
        if new_f == 0.0 {
            return new_x;
        }
        // Stop as soon as a step fails to reduce the residual.
        if new_f.abs() >= f.abs() {
            break;
        }
        x = new_x;
        f = new_f;
    }
    x
}
602
/// Solve an arbitrary function for a zero-crossing.
///
/// This uses the [ITP method], as described in the paper
/// [An Enhancement of the Bisection Method Average Performance Preserving Minmax Optimality].
///
/// The values of `ya` and `yb` are given as arguments rather than
/// computed from `f`, as the values may already be known, or they may
/// be less expensive to compute as special cases.
///
/// It is assumed that `ya < 0.0` and `yb > 0.0`, otherwise unexpected
/// results may occur.
///
/// The `a` and `b` parameters represent the lower and upper bounds of the
/// bracket searched for a solution. Iteration stops once the bracket is no
/// wider than `2 * epsilon`, so `epsilon` must be positive and large enough
/// for the bracket to actually shrink that far in floating point. The
/// internal `2^(n0 + n_1/2)` scale factor is computed without shifting past
/// 63 bits, so a large `n0` or a very small `epsilon` cannot cause integer
/// overflow.
///
/// The ITP method has tuning parameters. This implementation hardwires
/// k2 to 2, both because it avoids an expensive floating point
/// exponentiation, and because this value has been tested to work well
/// with curve fitting problems.
///
/// The `n0` parameter controls the relative impact of the bisection and
/// secant components. When it is 0, the number of iterations is
/// guaranteed to be no more than the number required by bisection (thus,
/// this method is strictly superior to bisection). However, when the
/// function is smooth, a value of 1 gives the secant method more of a
/// chance to engage, so the average number of iterations is likely
/// lower, though there can be one more iteration than bisection in the
/// worst case.
///
/// The `k1` parameter is harder to characterize, and interested users
/// are referred to the paper, as well as encouraged to do empirical
/// testing. To match the paper, a value of `0.2 / (b - a)` is
/// suggested, and this is confirmed to give good results.
///
/// When the function is monotonic, the returned result is guaranteed to
/// be within `epsilon` of the zero crossing. For more detailed analysis,
/// again see the paper.
///
/// [ITP method]: https://en.wikipedia.org/wiki/ITP_Method
/// [An Enhancement of the Bisection Method Average Performance Preserving Minmax Optimality]: https://dl.acm.org/doi/10.1145/3423597
// The tuning parameters and the known endpoint values are independent inputs;
// bundling them into a struct would only churn the public API.
#[allow(clippy::too_many_arguments)]
pub fn solve_itp(
    mut f: impl FnMut(f64) -> f64,
    mut a: f64,
    mut b: f64,
    epsilon: f64,
    n0: usize,
    k1: f64,
    mut ya: f64,
    mut yb: f64,
) -> f64 {
    let n1_2 = (((b - a) / epsilon).log2().ceil() - 1.0).max(0.0) as usize;
    let nmax = n0.saturating_add(n1_2);
    // 2^nmax as f64. The integer shift is exact for the common case; beyond
    // 63 bits fall back to `powi` (which saturates to infinity) instead of
    // overflowing the shift.
    let pow2_nmax = if nmax < 64 {
        (1u64 << nmax) as f64
    } else {
        2f64.powi(nmax.min(2048) as i32)
    };
    let mut scaled_epsilon = epsilon * pow2_nmax;
    while b - a > 2.0 * epsilon {
        let x1_2 = 0.5 * (a + b);
        let r = scaled_epsilon - 0.5 * (b - a);
        // Regula falsi (secant) estimate.
        let xf = (yb * a - ya * b) / (yb - ya);
        let sigma = x1_2 - xf;
        // This has k2 = 2 hardwired for efficiency.
        let delta = k1 * (b - a).powi(2);
        // Truncation: perturb the estimate toward the midpoint.
        let xt = if delta <= (x1_2 - xf).abs() {
            xf + delta.copysign(sigma)
        } else {
            x1_2
        };
        // Projection: keep the estimate within `r` of the midpoint.
        let xitp = if (xt - x1_2).abs() <= r {
            xt
        } else {
            x1_2 - r.copysign(sigma)
        };
        let yitp = f(xitp);
        if yitp > 0.0 {
            b = xitp;
            yb = yitp;
        } else if yitp < 0.0 {
            a = xitp;
            ya = yitp;
        } else {
            return xitp;
        }
        scaled_epsilon *= 0.5;
    }
    0.5 * (a + b)
}
690
/// A variant ITP solver that allows fallible functions.
///
/// Another difference: it returns the bracket that contains the root,
/// which may be important if the function has a discontinuity.
///
/// The parameters have the same meaning as for `solve_itp`. The first
/// error returned by `f` is propagated immediately. If `f` evaluates to
/// exactly zero, the degenerate bracket `(x, x)` is returned. As in
/// `solve_itp`, the `2^(n0 + n_1/2)` scale factor is computed without
/// shifting past 63 bits, so a large `n0` cannot cause integer overflow.
// The tuning parameters and the known endpoint values are independent inputs.
#[allow(clippy::too_many_arguments)]
pub(crate) fn solve_itp_fallible<E>(
    mut f: impl FnMut(f64) -> Result<f64, E>,
    mut a: f64,
    mut b: f64,
    epsilon: f64,
    n0: usize,
    k1: f64,
    mut ya: f64,
    mut yb: f64,
) -> Result<(f64, f64), E> {
    let n1_2 = (((b - a) / epsilon).log2().ceil() - 1.0).max(0.0) as usize;
    let nmax = n0.saturating_add(n1_2);
    // 2^nmax as f64, avoiding an overflowing shift for nmax >= 64.
    let pow2_nmax = if nmax < 64 {
        (1u64 << nmax) as f64
    } else {
        2f64.powi(nmax.min(2048) as i32)
    };
    let mut scaled_epsilon = epsilon * pow2_nmax;
    while b - a > 2.0 * epsilon {
        let x1_2 = 0.5 * (a + b);
        let r = scaled_epsilon - 0.5 * (b - a);
        // Regula falsi (secant) estimate.
        let xf = (yb * a - ya * b) / (yb - ya);
        let sigma = x1_2 - xf;
        // This has k2 = 2 hardwired for efficiency.
        let delta = k1 * (b - a).powi(2);
        // Truncation: perturb the estimate toward the midpoint.
        let xt = if delta <= (x1_2 - xf).abs() {
            xf + delta.copysign(sigma)
        } else {
            x1_2
        };
        // Projection: keep the estimate within `r` of the midpoint.
        let xitp = if (xt - x1_2).abs() <= r {
            xt
        } else {
            x1_2 - r.copysign(sigma)
        };
        let yitp = f(xitp)?;
        if yitp > 0.0 {
            b = xitp;
            yb = yitp;
        } else if yitp < 0.0 {
            a = xitp;
            ya = yitp;
        } else {
            return Ok((xitp, xitp));
        }
        scaled_epsilon *= 0.5;
    }
    Ok((a, b))
}
740
741// Tables of Legendre-Gauss quadrature coefficients, adapted from:
742// <https://pomax.github.io/bezierinfo/legendre-gauss.html>
743
/// 3-point Gauss-Legendre quadrature rule on [-1, 1], as `(weight, abscissa)` pairs.
pub const GAUSS_LEGENDRE_COEFFS_3: &[(f64, f64)] = &[
    (0.8888888888888888, 0.0000000000000000),
    (0.5555555555555556, -0.7745966692414834),
    (0.5555555555555556, 0.7745966692414834),
];
749
/// 4-point Gauss-Legendre quadrature rule on [-1, 1], as `(weight, abscissa)` pairs.
pub const GAUSS_LEGENDRE_COEFFS_4: &[(f64, f64)] = &[
    (0.6521451548625461, -0.3399810435848563),
    (0.6521451548625461, 0.3399810435848563),
    (0.3478548451374538, -0.8611363115940526),
    (0.3478548451374538, 0.8611363115940526),
];
756
/// 5-point Gauss-Legendre quadrature rule on [-1, 1], as `(weight, abscissa)` pairs.
pub const GAUSS_LEGENDRE_COEFFS_5: &[(f64, f64)] = &[
    (0.5688888888888889, 0.0000000000000000),
    (0.4786286704993665, -0.5384693101056831),
    (0.4786286704993665, 0.5384693101056831),
    (0.2369268850561891, -0.9061798459386640),
    (0.2369268850561891, 0.9061798459386640),
];
764
/// 6-point Gauss-Legendre quadrature rule on [-1, 1], as `(weight, abscissa)` pairs.
pub const GAUSS_LEGENDRE_COEFFS_6: &[(f64, f64)] = &[
    (0.3607615730481386, 0.6612093864662645),
    (0.3607615730481386, -0.6612093864662645),
    (0.4679139345726910, -0.2386191860831969),
    (0.4679139345726910, 0.2386191860831969),
    (0.1713244923791704, -0.9324695142031521),
    (0.1713244923791704, 0.9324695142031521),
];
773
/// 7-point Gauss-Legendre quadrature rule on [-1, 1], as `(weight, abscissa)` pairs.
pub const GAUSS_LEGENDRE_COEFFS_7: &[(f64, f64)] = &[
    (0.4179591836734694, 0.0000000000000000),
    (0.3818300505051189, 0.4058451513773972),
    (0.3818300505051189, -0.4058451513773972),
    (0.2797053914892766, -0.7415311855993945),
    (0.2797053914892766, 0.7415311855993945),
    (0.1294849661688697, -0.9491079123427585),
    (0.1294849661688697, 0.9491079123427585),
];
783
/// 8-point Gauss-Legendre quadrature rule on [-1, 1], as `(weight, abscissa)` pairs.
pub const GAUSS_LEGENDRE_COEFFS_8: &[(f64, f64)] = &[
    (0.3626837833783620, -0.1834346424956498),
    (0.3626837833783620, 0.1834346424956498),
    (0.3137066458778873, -0.5255324099163290),
    (0.3137066458778873, 0.5255324099163290),
    (0.2223810344533745, -0.7966664774136267),
    (0.2223810344533745, 0.7966664774136267),
    (0.1012285362903763, -0.9602898564975363),
    (0.1012285362903763, 0.9602898564975363),
];
794
/// The positive-abscissa half of the 8-point Gauss-Legendre rule, as
/// `(weight, abscissa)` pairs. Presumably callers account for the mirrored
/// negative nodes themselves (e.g. for symmetric integrands); confirm at call sites.
pub const GAUSS_LEGENDRE_COEFFS_8_HALF: &[(f64, f64)] = &[
    (0.3626837833783620, 0.1834346424956498),
    (0.3137066458778873, 0.5255324099163290),
    (0.2223810344533745, 0.7966664774136267),
    (0.1012285362903763, 0.9602898564975363),
];
801
/// 9-point Gauss-Legendre quadrature rule on [-1, 1], as `(weight, abscissa)` pairs.
pub const GAUSS_LEGENDRE_COEFFS_9: &[(f64, f64)] = &[
    (0.3302393550012598, 0.0000000000000000),
    (0.1806481606948574, -0.8360311073266358),
    (0.1806481606948574, 0.8360311073266358),
    (0.0812743883615744, -0.9681602395076261),
    (0.0812743883615744, 0.9681602395076261),
    (0.3123470770400029, -0.3242534234038089),
    (0.3123470770400029, 0.3242534234038089),
    (0.2606106964029354, -0.6133714327005904),
    (0.2606106964029354, 0.6133714327005904),
];
813
/// 11-point Gauss-Legendre quadrature rule on [-1, 1], as `(weight, abscissa)` pairs.
pub const GAUSS_LEGENDRE_COEFFS_11: &[(f64, f64)] = &[
    (0.2729250867779006, 0.0000000000000000),
    (0.2628045445102467, -0.2695431559523450),
    (0.2628045445102467, 0.2695431559523450),
    (0.2331937645919905, -0.5190961292068118),
    (0.2331937645919905, 0.5190961292068118),
    (0.1862902109277343, -0.7301520055740494),
    (0.1862902109277343, 0.7301520055740494),
    (0.1255803694649046, -0.8870625997680953),
    (0.1255803694649046, 0.8870625997680953),
    (0.0556685671161737, -0.9782286581460570),
    (0.0556685671161737, 0.9782286581460570),
];
827
/// 16-point Gauss-Legendre quadrature rule on [-1, 1], as `(weight, abscissa)` pairs.
pub const GAUSS_LEGENDRE_COEFFS_16: &[(f64, f64)] = &[
    (0.1894506104550685, -0.0950125098376374),
    (0.1894506104550685, 0.0950125098376374),
    (0.1826034150449236, -0.2816035507792589),
    (0.1826034150449236, 0.2816035507792589),
    (0.1691565193950025, -0.4580167776572274),
    (0.1691565193950025, 0.4580167776572274),
    (0.1495959888165767, -0.6178762444026438),
    (0.1495959888165767, 0.6178762444026438),
    (0.1246289712555339, -0.7554044083550030),
    (0.1246289712555339, 0.7554044083550030),
    (0.0951585116824928, -0.8656312023878318),
    (0.0951585116824928, 0.8656312023878318),
    (0.0622535239386479, -0.9445750230732326),
    (0.0622535239386479, 0.9445750230732326),
    (0.0271524594117541, -0.9894009349916499),
    (0.0271524594117541, 0.9894009349916499),
];
846
/// The positive-abscissa half of the 16-point Gauss-Legendre rule (just the
/// positive x_i values), as `(weight, abscissa)` pairs. Presumably callers
/// account for the mirrored negative nodes themselves; confirm at call sites.
pub const GAUSS_LEGENDRE_COEFFS_16_HALF: &[(f64, f64)] = &[
    (0.1894506104550685, 0.0950125098376374),
    (0.1826034150449236, 0.2816035507792589),
    (0.1691565193950025, 0.4580167776572274),
    (0.1495959888165767, 0.6178762444026438),
    (0.1246289712555339, 0.7554044083550030),
    (0.0951585116824928, 0.8656312023878318),
    (0.0622535239386479, 0.9445750230732326),
    (0.0271524594117541, 0.9894009349916499),
];
858
/// 24-point Gauss-Legendre quadrature rule on the interval `[-1, 1]`.
///
/// Each entry is a `(weight, abscissa)` pair `(w_i, x_i)`. Nodes are listed as
/// `(-x, +x)` pairs in order of increasing `|x|`. The weights sum to 2.
pub const GAUSS_LEGENDRE_COEFFS_24: &[(f64, f64)] = &[
    (0.1279381953467522, -0.0640568928626056),
    (0.1279381953467522, 0.0640568928626056),
    (0.1258374563468283, -0.1911188674736163),
    (0.1258374563468283, 0.1911188674736163),
    (0.1216704729278034, -0.3150426796961634),
    (0.1216704729278034, 0.3150426796961634),
    (0.1155056680537256, -0.4337935076260451),
    (0.1155056680537256, 0.4337935076260451),
    (0.1074442701159656, -0.5454214713888396),
    (0.1074442701159656, 0.5454214713888396),
    (0.0976186521041139, -0.6480936519369755),
    (0.0976186521041139, 0.6480936519369755),
    (0.0861901615319533, -0.7401241915785544),
    (0.0861901615319533, 0.7401241915785544),
    (0.0733464814110803, -0.8200019859739029),
    (0.0733464814110803, 0.8200019859739029),
    (0.0592985849154368, -0.8864155270044011),
    (0.0592985849154368, 0.8864155270044011),
    (0.0442774388174198, -0.9382745520027328),
    (0.0442774388174198, 0.9382745520027328),
    (0.0285313886289337, -0.9747285559713095),
    (0.0285313886289337, 0.9747285559713095),
    (0.0123412297999872, -0.9951872199970213),
    (0.0123412297999872, 0.9951872199970213),
];
885
/// The positive-abscissa half of the 24-point Gauss-Legendre rule on `[-1, 1]`.
///
/// Entries are `(weight, abscissa)` pairs containing only the positive `x_i` values,
/// in increasing order. The rule is symmetric, so the full rule is recovered by
/// evaluating each entry at both `x` and `-x` with the same weight. Accordingly, the
/// weights listed here sum to 1 rather than 2.
pub const GAUSS_LEGENDRE_COEFFS_24_HALF: &[(f64, f64)] = &[
    (0.1279381953467522, 0.0640568928626056),
    (0.1258374563468283, 0.1911188674736163),
    (0.1216704729278034, 0.3150426796961634),
    (0.1155056680537256, 0.4337935076260451),
    (0.1074442701159656, 0.5454214713888396),
    (0.0976186521041139, 0.6480936519369755),
    (0.0861901615319533, 0.7401241915785544),
    (0.0733464814110803, 0.8200019859739029),
    (0.0592985849154368, 0.8864155270044011),
    (0.0442774388174198, 0.9382745520027328),
    (0.0285313886289337, 0.9747285559713095),
    (0.0123412297999872, 0.9951872199970213),
];
900
/// 32-point Gauss-Legendre quadrature rule on the interval `[-1, 1]`.
///
/// Each entry is a `(weight, abscissa)` pair `(w_i, x_i)`. Nodes are listed as
/// `(-x, +x)` pairs in order of increasing `|x|`. The weights sum to 2.
pub const GAUSS_LEGENDRE_COEFFS_32: &[(f64, f64)] = &[
    (0.0965400885147278, -0.0483076656877383),
    (0.0965400885147278, 0.0483076656877383),
    (0.0956387200792749, -0.1444719615827965),
    (0.0956387200792749, 0.1444719615827965),
    (0.0938443990808046, -0.2392873622521371),
    (0.0938443990808046, 0.2392873622521371),
    (0.0911738786957639, -0.3318686022821277),
    (0.0911738786957639, 0.3318686022821277),
    (0.0876520930044038, -0.4213512761306353),
    (0.0876520930044038, 0.4213512761306353),
    (0.0833119242269467, -0.5068999089322294),
    (0.0833119242269467, 0.5068999089322294),
    (0.0781938957870703, -0.5877157572407623),
    (0.0781938957870703, 0.5877157572407623),
    (0.0723457941088485, -0.6630442669302152),
    (0.0723457941088485, 0.6630442669302152),
    (0.0658222227763618, -0.7321821187402897),
    (0.0658222227763618, 0.7321821187402897),
    (0.0586840934785355, -0.7944837959679424),
    (0.0586840934785355, 0.7944837959679424),
    (0.0509980592623762, -0.8493676137325700),
    (0.0509980592623762, 0.8493676137325700),
    (0.0428358980222267, -0.8963211557660521),
    (0.0428358980222267, 0.8963211557660521),
    (0.0342738629130214, -0.9349060759377397),
    (0.0342738629130214, 0.9349060759377397),
    (0.0253920653092621, -0.9647622555875064),
    (0.0253920653092621, 0.9647622555875064),
    (0.0162743947309057, -0.9856115115452684),
    (0.0162743947309057, 0.9856115115452684),
    (0.0070186100094701, -0.9972638618494816),
    (0.0070186100094701, 0.9972638618494816),
];
935
/// The positive-abscissa half of the 32-point Gauss-Legendre rule on `[-1, 1]`.
///
/// Entries are `(weight, abscissa)` pairs containing only the positive `x_i` values,
/// in increasing order. The rule is symmetric, so the full rule is recovered by
/// evaluating each entry at both `x` and `-x` with the same weight. Accordingly, the
/// weights listed here sum to 1 rather than 2.
pub const GAUSS_LEGENDRE_COEFFS_32_HALF: &[(f64, f64)] = &[
    (0.0965400885147278, 0.0483076656877383),
    (0.0956387200792749, 0.1444719615827965),
    (0.0938443990808046, 0.2392873622521371),
    (0.0911738786957639, 0.3318686022821277),
    (0.0876520930044038, 0.4213512761306353),
    (0.0833119242269467, 0.5068999089322294),
    (0.0781938957870703, 0.5877157572407623),
    (0.0723457941088485, 0.6630442669302152),
    (0.0658222227763618, 0.7321821187402897),
    (0.0586840934785355, 0.7944837959679424),
    (0.0509980592623762, 0.8493676137325700),
    (0.0428358980222267, 0.8963211557660521),
    (0.0342738629130214, 0.9349060759377397),
    (0.0253920653092621, 0.9647622555875064),
    (0.0162743947309057, 0.9856115115452684),
    (0.0070186100094701, 0.9972638618494816),
];
954
#[cfg(test)]
mod tests {
    use crate::common::*;
    use arrayvec::ArrayVec;

    /// Asserts that `roots` matches `expected` element-wise within an absolute
    /// tolerance of `1e-12`.
    ///
    /// `roots` is sorted ascending with `f64::total_cmp` (the same ordering used by
    /// `test_solve_quartic`) before comparison, so `expected` must be listed in
    /// ascending order.
    fn verify<const N: usize>(mut roots: ArrayVec<f64, N>, expected: &[f64]) {
        assert_eq!(
            expected.len(),
            roots.len(),
            "wrong number of roots: got {:?}, expected {:?}",
            roots,
            expected
        );
        let epsilon = 1e-12;
        roots.sort_by(f64::total_cmp);
        for (actual, expected) in roots.iter().zip(expected) {
            assert!(
                (actual - expected).abs() < epsilon,
                "actual {:e}, expected {:e}, err {:e}",
                actual,
                expected,
                actual - expected
            );
        }
    }

    #[test]
    fn test_solve_cubic() {
        verify(solve_cubic(-5.0, 0.0, 0.0, 1.0), &[5.0f64.cbrt()]);
        verify(solve_cubic(-5.0, -1.0, 0.0, 1.0), &[1.90416085913492]);
        verify(solve_cubic(0.0, -1.0, 0.0, 1.0), &[-1.0, 0.0, 1.0]);
        verify(solve_cubic(-2.0, -3.0, 0.0, 1.0), &[-1.0, 2.0]);
        verify(solve_cubic(2.0, -3.0, 0.0, 1.0), &[-2.0, 1.0]);
        verify(
            solve_cubic(2.0 - 1e-12, 5.0, 4.0, 1.0),
            &[
                -1.9999999999989995,
                -1.0000010000848456,
                -0.9999989999161546,
            ],
        );
        verify(solve_cubic(2.0 + 1e-12, 5.0, 4.0, 1.0), &[-2.0]);
    }

    #[test]
    fn test_solve_quadratic() {
        verify(
            solve_quadratic(-5.0, 0.0, 1.0),
            &[-(5.0f64.sqrt()), 5.0f64.sqrt()],
        );
        verify(solve_quadratic(5.0, 0.0, 1.0), &[]);
        verify(solve_quadratic(5.0, 1.0, 0.0), &[-5.0]);
        verify(solve_quadratic(1.0, 2.0, 1.0), &[-1.0]);
    }

    #[test]
    fn test_solve_quartic() {
        // These test cases are taken from Orellana and De Michele paper (Table 1).
        fn test_with_roots(coeffs: [f64; 4], roots: &[f64], rel_err: f64) {
            // Note: in paper, coefficients are in decreasing order.
            let mut actual = solve_quartic(coeffs[3], coeffs[2], coeffs[1], coeffs[0], 1.0);
            actual.sort_by(f64::total_cmp);
            assert_eq!(actual.len(), roots.len());
            for (actual, expected) in actual.iter().zip(roots) {
                assert!(
                    (actual - expected).abs() < rel_err * expected.abs(),
                    "actual {:e}, expected {:e}, err {:e}",
                    actual,
                    expected,
                    actual - expected
                );
            }
        }

        // Builds the monic quartic with roots x1..x4 from Vieta's formulas and checks
        // that the solver returns `roots` (which may omit repeated roots).
        fn test_vieta_roots(x1: f64, x2: f64, x3: f64, x4: f64, roots: &[f64], rel_err: f64) {
            let a = -(x1 + x2 + x3 + x4);
            let b = x1 * (x2 + x3) + x2 * (x3 + x4) + x4 * (x1 + x3);
            let c = -x1 * x2 * (x3 + x4) - x3 * x4 * (x1 + x2);
            let d = x1 * x2 * x3 * x4;
            test_with_roots([a, b, c, d], roots, rel_err);
        }

        fn test_vieta(x1: f64, x2: f64, x3: f64, x4: f64, rel_err: f64) {
            test_vieta_roots(x1, x2, x3, x4, &[x1, x2, x3, x4], rel_err);
        }

        // case 1
        test_vieta(1., 1e3, 1e6, 1e9, 1e-16);
        // case 2
        test_vieta(2., 2.001, 2.002, 2.003, 1e-6);
        // case 3
        test_vieta(1e47, 1e49, 1e50, 1e53, 2e-16);
        // case 4
        test_vieta(-1., 1., 2., 1e14, 1e-16);
        // case 5
        test_vieta(-2e7, -1., 1., 1e7, 1e-16);
        // case 6
        test_with_roots(
            [-9000002.0, -9999981999998.0, 19999982e6, -2e13],
            &[-1e6, 1e7],
            1e-16,
        );
        // case 7
        test_with_roots(
            [2000011.0, 1010022000028.0, 11110056e6, 2828e10],
            &[-7., -4.],
            1e-16,
        );
        // case 8
        test_with_roots(
            [-100002011.0, 201101022001.0, -102200111000011.0, 11000011e8],
            &[11., 1e8],
            1e-16,
        );
        // cases 9-13 have no real roots
        // case 14
        test_vieta_roots(1000., 1000., 1000., 1000., &[1000., 1000.], 1e-16);
        // case 15
        test_vieta_roots(1e-15, 1000., 1000., 1000., &[1e-15, 1000., 1000.], 1e-15);
        // case 16 no real roots
        // case 17
        test_vieta(10000., 10001., 10010., 10100., 1e-6);
        // case 19
        test_vieta_roots(1., 1e30, 1e30, 1e44, &[1., 1e30, 1e44], 1e-16);
        // case 20: the error is too big for a tight bound, so the relative tolerance
        // is loosened to 1e-7 here.
        test_vieta(1., 1e7, 1e7, 1e14, 1e-7);
        // case 21 omitted: the solver doesn't pick up its double root.
        // case 22
        test_vieta(1., 10., 1e152, 1e154, 3e-16);
        // case 23
        test_with_roots(
            [1., 1., 3. / 8., 1e-3],
            &[-0.497314148060048, -0.00268585193995149],
            2e-15,
        );
        // case 24
        const S: f64 = 1e30;
        test_with_roots(
            [-(1. + 1. / S), 1. / S - S * S, S * S + S, -S],
            &[-S, 1e-30, 1., S],
            2e-16,
        );
    }

    #[test]
    fn test_solve_itp() {
        let f = |x: f64| x.powi(3) - x - 2.0;
        let x = solve_itp(f, 1., 2., 1e-12, 0, 0.2, f(1.), f(2.));
        assert!(f(x).abs() < 6e-12);
    }

    /// Uses `solve_itp` to invert arclength on a cubic: finds `t` such that the
    /// sub-curve `0..t` has arclength `target`.
    #[test]
    fn test_inv_arclen() {
        use crate::{ParamCurve, ParamCurveArclen};
        let c = crate::CubicBez::new(
            (0.0, 0.0),
            (100.0 / 3.0, 0.0),
            (200.0 / 3.0, 100.0 / 3.0),
            (100.0, 100.0),
        );
        let target = 100.0;
        let t = solve_itp(
            |t| c.subsegment(0.0..t).arclen(1e-9) - target,
            0.,
            1.,
            1e-6,
            1,
            0.2,
            -target,
            c.arclen(1e-9) - target,
        );
        assert!((0.0..=1.0).contains(&t), "t out of range: {:e}", t);
        // The root is requested to a tolerance of 1e-6 in `t`. The curve's speed is at
        // most 3x its longest control-polygon leg (~224), so the arclength at `t` should
        // be within roughly 2.2e-4 of the target. 1e-3 leaves comfortable slack.
        let err = c.subsegment(0.0..t).arclen(1e-9) - target;
        assert!(err.abs() < 1e-3, "t {:e}, arclen error {:e}", t, err);
    }
}