kurbo/
common.rs

1// Copyright 2018 the Kurbo Authors
2// SPDX-License-Identifier: Apache-2.0 OR MIT
3
4//! Common mathematical operations
5
6#![allow(missing_docs)]
7
#[cfg(not(feature = "std"))]
mod sealed {
    /// A [sealed trait](https://predr.ag/blog/definitive-guide-to-sealed-traits-in-rust/)
    /// which stops [`super::FloatFuncs`] from being implemented outside kurbo. This could
    /// be relaxed in the future if there are good reasons to allow external impls.
    /// The benefit from being sealed is that we can add methods without breaking downstream
    /// implementations.
    pub trait FloatFuncsSealed {}
}
17
18use arrayvec::ArrayVec;
19
/// Defines a trait that chooses between libstd or libm implementations of float methods.
///
/// Some methods will eventually become available in core by
/// [`core_float_math`](https://github.com/rust-lang/rust/issues/137578)
///
/// Each input row has the form
/// `fn name(self, args...) -> Ret => libm_f64_name/libm_f32_name;`.
/// Everything the macro emits is gated on `not(feature = "std")`.
macro_rules! define_float_funcs {
    // `$lname` is the libm function used for `f64`, `$lfname` the one used for `f32`.
    ($(
        fn $name:ident(self $(,$arg:ident: $arg_ty:ty)*) -> $ret:ty
        => $lname:ident/$lfname:ident;
    )+) => {

        /// Since core doesn't depend upon libm, this provides libm implementations
        /// of float functions which are typically provided by the std library, when
        /// the `std` feature is not enabled.
        ///
        /// For documentation see the respective functions in the std library.
        #[cfg(not(feature = "std"))]
        pub trait FloatFuncs : Sized + sealed::FloatFuncsSealed {
            /// For documentation see <https://doc.rust-lang.org/std/primitive.f64.html#method.signum>
            ///
            /// Special implementation, because libm doesn't have it.
            fn signum(self) -> Self;

            /// For documentation see <https://doc.rust-lang.org/std/primitive.f64.html#method.rem_euclid>
            ///
            /// Special implementation, because libm doesn't have it.
            fn rem_euclid(self, rhs: Self) -> Self;

            // One trait method per row passed to the macro.
            $(fn $name(self $(,$arg: $arg_ty)*) -> $ret;)+
        }

        #[cfg(not(feature = "std"))]
        impl sealed::FloatFuncsSealed for f32 {}

        #[cfg(not(feature = "std"))]
        impl FloatFuncs for f32 {
            #[inline]
            fn signum(self) -> f32 {
                // NaN has no sign to report; everything else is +/-1 with the sign of `self`.
                if self.is_nan() {
                    f32::NAN
                } else {
                    1.0_f32.copysign(self)
                }
            }

            #[inline]
            fn rem_euclid(self, rhs: Self) -> Self {
                // `%` keeps the sign of the dividend; shift negative remainders up by |rhs|.
                let r = self % rhs;
                if r < 0.0 {
                    r + rhs.abs()
                } else {
                    r
                }
            }

            // Forward each method to the `f32` libm function. Extra arguments are cast
            // with `as _`, so a non-float argument (e.g. an `i32` exponent) is converted
            // to whatever type the libm function expects.
            $(fn $name(self $(,$arg: $arg_ty)*) -> $ret {
                #[cfg(feature = "libm")]
                return libm::$lfname(self $(,$arg as _)*);

                #[cfg(not(feature = "libm"))]
                compile_error!("kurbo requires either the `std` or `libm` feature")
            })+
        }

        #[cfg(not(feature = "std"))]
        impl sealed::FloatFuncsSealed for f64 {}
        #[cfg(not(feature = "std"))]
        impl FloatFuncs for f64 {
            #[inline]
            fn signum(self) -> f64 {
                // NaN has no sign to report; everything else is +/-1 with the sign of `self`.
                if self.is_nan() {
                    f64::NAN
                } else {
                    1.0_f64.copysign(self)
                }
            }

            #[inline]
            fn rem_euclid(self, rhs: Self) -> Self {
                // `%` keeps the sign of the dividend; shift negative remainders up by |rhs|.
                let r = self % rhs;
                if r < 0.0 {
                    r + rhs.abs()
                } else {
                    r
                }
            }

            // Same forwarding as the `f32` impl, but to the `f64` libm function.
            $(fn $name(self $(,$arg: $arg_ty)*) -> $ret {
                #[cfg(feature = "libm")]
                return libm::$lname(self $(,$arg as _)*);

                #[cfg(not(feature = "libm"))]
                compile_error!("kurbo requires either the `std` or `libm` feature")
            })+
        }
    }
}
116
// Table of float methods provided when `std` is disabled.
// Row format: `fn <method>(self, ...) -> <ret> => <libm f64 fn>/<libm f32 fn>;`
define_float_funcs! {
    fn abs(self) -> Self => fabs/fabsf;
    fn acos(self) -> Self => acos/acosf;
    fn atan2(self, other: Self) -> Self => atan2/atan2f;
    fn cbrt(self) -> Self => cbrt/cbrtf;
    fn ceil(self) -> Self => ceil/ceilf;
    fn cos(self) -> Self => cos/cosf;
    fn copysign(self, sign: Self) -> Self => copysign/copysignf;
    fn floor(self) -> Self => floor/floorf;
    fn hypot(self, other: Self) -> Self => hypot/hypotf;
    fn ln(self) -> Self => log/logf;
    fn log2(self) -> Self => log2/log2f;
    fn mul_add(self, a: Self, b: Self) -> Self => fma/fmaf;
    // `powi` shares `pow`/`powf` with `powf`; the macro casts the `i32` exponent with `as _`.
    fn powi(self, n: i32) -> Self => pow/powf;
    fn powf(self, n: Self) -> Self => pow/powf;
    fn round(self) -> Self => round/roundf;
    fn sin(self) -> Self => sin/sinf;
    fn sin_cos(self) -> (Self, Self) => sincos/sincosf;
    fn sqrt(self) -> Self => sqrt/sqrtf;
    fn tan(self) -> Self => tan/tanf;
    fn trunc(self) -> Self => trunc/truncf;
}
139
/// Extra convenience methods for `f32` and `f64`.
pub trait FloatExt<T> {
    /// Rounds a non-integer value to the next integer away from zero.
    ///
    /// Values that are already integers are returned unchanged. In other
    /// words, `expand` relates to `ceil` the way `trunc` relates to `floor`.
    ///
    /// # Examples
    ///
    /// ```
    /// use kurbo::common::FloatExt;
    ///
    /// let f = 3.7_f64;
    /// let g = 3.0_f64;
    /// let h = -3.7_f64;
    /// let i = -5.1_f32;
    ///
    /// assert_eq!(f.expand(), 4.0);
    /// assert_eq!(g.expand(), 3.0);
    /// assert_eq!(h.expand(), -4.0);
    /// assert_eq!(i.expand(), -6.0);
    /// ```
    fn expand(&self) -> T;
}

impl FloatExt<f64> for f64 {
    #[inline]
    fn expand(&self) -> f64 {
        let value = *self;
        // Round the magnitude up, then restore the original sign.
        let magnitude = value.abs().ceil();
        magnitude.copysign(value)
    }
}

impl FloatExt<f32> for f32 {
    #[inline]
    fn expand(&self) -> f32 {
        let value = *self;
        // Round the magnitude up, then restore the original sign.
        let magnitude = value.abs().ceil();
        magnitude.copysign(value)
    }
}
178
179/// Find real roots of cubic equation.
180///
181/// The implementation is not (yet) fully robust, but it does handle the case
182/// where `c3` is zero (in that case, solving the quadratic equation).
183///
184/// See: <https://momentsingraphics.de/CubicRoots.html>
185///
186/// That implementation is in turn based on Jim Blinn's "How to Solve a Cubic
187/// Equation", which is masterful.
188///
189/// Return values of x for which c0 + c1 x + c2 x² + c3 x³ = 0.
190pub fn solve_cubic(c0: f64, c1: f64, c2: f64, c3: f64) -> ArrayVec<f64, 3> {
191    let mut result = ArrayVec::new();
192    let c3_recip = c3.recip();
193    const ONETHIRD: f64 = 1. / 3.;
194    let scaled_c2 = c2 * (ONETHIRD * c3_recip);
195    let scaled_c1 = c1 * (ONETHIRD * c3_recip);
196    let scaled_c0 = c0 * c3_recip;
197    if !(scaled_c0.is_finite() && scaled_c1.is_finite() && scaled_c2.is_finite()) {
198        // cubic coefficient is zero or nearly so.
199        return solve_quadratic(c0, c1, c2).iter().copied().collect();
200    }
201    let (c0, c1, c2) = (scaled_c0, scaled_c1, scaled_c2);
202    // (d0, d1, d2) is called "Delta" in article
203    let d0 = (-c2).mul_add(c2, c1);
204    let d1 = (-c1).mul_add(c2, c0);
205    let d2 = c2 * c0 - c1 * c1;
206    // d is called "Discriminant"
207    let d = 4.0 * d0 * d2 - d1 * d1;
208    // de is called "Depressed.x", Depressed.y = d0
209    let de = (-2.0 * c2).mul_add(d0, d1);
210    // TODO: handle the cases where these intermediate results overflow.
211    if d < 0.0 {
212        let sq = (-0.25 * d).sqrt();
213        let r = -0.5 * de;
214        let t1 = (r + sq).cbrt() + (r - sq).cbrt();
215        result.push(t1 - c2);
216    } else if d == 0.0 {
217        let t1 = (-d0).sqrt().copysign(de);
218        result.push(t1 - c2);
219        result.push(-2.0 * t1 - c2);
220    } else {
221        let th = d.sqrt().atan2(-de) * ONETHIRD;
222        // (th_cos, th_sin) is called "CubicRoot"
223        let (th_sin, th_cos) = th.sin_cos();
224        // (r0, r1, r2) is called "Root"
225        let r0 = th_cos;
226        let ss3 = th_sin * 3.0f64.sqrt();
227        let r1 = 0.5 * (-th_cos + ss3);
228        let r2 = 0.5 * (-th_cos - ss3);
229        let t = 2.0 * (-d0).sqrt();
230        result.push(t.mul_add(r0, -c2));
231        result.push(t.mul_add(r1, -c2));
232        result.push(t.mul_add(r2, -c2));
233    }
234    result
235}
236
237/// Find real roots of quadratic equation.
238///
239/// Return values of x for which c0 + c1 x + c2 x² = 0.
240///
241/// This function tries to be quite numerically robust. If the equation
242/// is nearly linear, it will return the root ignoring the quadratic term;
243/// the other root might be out of representable range. In the degenerate
244/// case where all coefficients are zero, so that all values of x satisfy
245/// the equation, a single `0.0` is returned.
246pub fn solve_quadratic(c0: f64, c1: f64, c2: f64) -> ArrayVec<f64, 2> {
247    let mut result = ArrayVec::new();
248    let sc0 = c0 * c2.recip();
249    let sc1 = c1 * c2.recip();
250    if !sc0.is_finite() || !sc1.is_finite() {
251        // c2 is zero or very small, treat as linear eqn
252        let root = -c0 / c1;
253        if root.is_finite() {
254            result.push(root);
255        } else if c0 == 0.0 && c1 == 0.0 {
256            // Degenerate case
257            result.push(0.0);
258        }
259        return result;
260    }
261    let arg = sc1 * sc1 - 4. * sc0;
262    let root1 = if !arg.is_finite() {
263        // Likely, calculation of sc1 * sc1 overflowed. Find one root
264        // using sc1 x + x² = 0, other root as sc0 / root1.
265        -sc1
266    } else {
267        if arg < 0.0 {
268            return result;
269        } else if arg == 0.0 {
270            result.push(-0.5 * sc1);
271            return result;
272        }
273        // See https://math.stackexchange.com/questions/866331
274        -0.5 * (sc1 + arg.sqrt().copysign(sc1))
275    };
276    let root2 = sc0 / root1;
277    if root2.is_finite() {
278        // Sort just to be friendly and make results deterministic.
279        if root2 > root1 {
280            result.push(root1);
281            result.push(root2);
282        } else {
283            result.push(root2);
284            result.push(root1);
285        }
286    } else {
287        result.push(root1);
288    }
289    result
290}
291
/// Error of `raw` relative to the coefficient `a`.
///
/// Returns `|raw - a| / |a|`, or `|raw|` when `a` is zero. This is a helper
/// function from the Orellana and De Michele paper.
fn eps_rel(raw: f64, a: f64) -> f64 {
    let deviation = if a == 0.0 { raw } else { (raw - a) / a };
    deviation.abs()
}
302
303/// Find real roots of a quartic equation.
304///
305/// This is a fairly literal implementation of the method described in:
306/// Algorithm 1010: Boosting Efficiency in Solving Quartic Equations with
307/// No Compromise in Accuracy, Orellana and De Michele, ACM
308/// Transactions on Mathematical Software, Vol. 46, No. 2, May 2020.
309pub fn solve_quartic(c0: f64, c1: f64, c2: f64, c3: f64, c4: f64) -> ArrayVec<f64, 4> {
310    if c4 == 0.0 {
311        return solve_cubic(c0, c1, c2, c3).iter().copied().collect();
312    }
313    if c0 == 0.0 {
314        // Note: appends 0 root at end, doesn't sort. We might want to do that.
315        return solve_cubic(c1, c2, c3, c4)
316            .iter()
317            .copied()
318            .chain(Some(0.0))
319            .collect();
320    }
321    let a = c3 / c4;
322    let b = c2 / c4;
323    let c = c1 / c4;
324    let d = c0 / c4;
325    if let Some(result) = solve_quartic_inner(a, b, c, d, false) {
326        return result;
327    }
328    // Do polynomial rescaling
329    const K_Q: f64 = 7.16e76;
330    for rescale in [false, true] {
331        if let Some(result) = solve_quartic_inner(
332            a / K_Q,
333            b / K_Q.powi(2),
334            c / K_Q.powi(3),
335            d / K_Q.powi(4),
336            rescale,
337        ) {
338            return result.iter().map(|x| x * K_Q).collect();
339        }
340    }
341    // Overflow happened, just return no roots.
342    //println!("overflow, no roots returned");
343    ArrayVec::default()
344}
345
346fn solve_quartic_inner(a: f64, b: f64, c: f64, d: f64, rescale: bool) -> Option<ArrayVec<f64, 4>> {
347    factor_quartic_inner(a, b, c, d, rescale).map(|quadratics| {
348        quadratics
349            .iter()
350            .flat_map(|(a, b)| solve_quadratic(*b, *a, 1.0))
351            .collect()
352    })
353}
354
/// Factor a quartic into two quadratics.
///
/// Attempt to factor a quartic equation into two quadratic equations. Returns `None` either if there
/// is overflow (in which case rescaling might succeed) or the factorization would result in
/// complex coefficients.
///
/// The inputs are the coefficients of the monic quartic
/// x⁴ + a x³ + b x² + c x + d. Each returned pair `(alpha, beta)` stands for
/// the quadratic x² + alpha x + beta. When `rescale` is true, the shifted
/// coefficients are divided by an internal constant before the depressed
/// cubic is formed, and its root is scaled back afterwards.
///
/// Discussion question: distinguish the two cases in return value?
pub fn factor_quartic_inner(
    a: f64,
    b: f64,
    c: f64,
    d: f64,
    rescale: bool,
) -> Option<ArrayVec<(f64, f64), 2>> {
    // Summed relative errors of the x³, x² and x coefficients obtained by
    // multiplying out (x² + a1 x + b1)(x² + a2 x + b2).
    let calc_eps_q = |a1, b1, a2, b2| {
        let eps_a = eps_rel(a1 + a2, a);
        let eps_b = eps_rel(b1 + a1 * a2 + b2, b);
        let eps_c = eps_rel(b1 * a2 + a1 * b2, c);
        eps_a + eps_b + eps_c
    };
    // Same as `calc_eps_q`, plus the error of the constant term.
    let calc_eps_t = |a1, b1, a2, b2| calc_eps_q(a1, b1, a2, b2) + eps_rel(b1 * b2, d);
    let disc = 9. * a * a - 24. * b;
    let s = if disc >= 0.0 {
        -2. * b / (3. * a + disc.sqrt().copysign(a))
    } else {
        -0.25 * a
    };
    // Coefficients of the quartic after the substitution x -> x + s.
    let a_prime = a + 4. * s;
    let b_prime = b + 3. * s * (a + 2. * s);
    let c_prime = c + s * (2. * b + s * (3. * a + 4. * s));
    let d_prime = d + s * (c + s * (b + s * (a + s)));
    // Inputs to `depressed_cubic_dominant`, i.e. the cubic x³ + g' x + h'.
    let g_prime;
    let h_prime;
    const K_C: f64 = 3.49e102;
    if rescale {
        // Same formulas as the branch below, with the shifted coefficients
        // divided by K_C first.
        let a_prime_s = a_prime / K_C;
        let b_prime_s = b_prime / K_C;
        let c_prime_s = c_prime / K_C;
        let d_prime_s = d_prime / K_C;
        g_prime = a_prime_s * c_prime_s - (4. / K_C) * d_prime_s - (1. / 3.) * b_prime_s.powi(2);
        h_prime = (a_prime_s * c_prime_s + (8. / K_C) * d_prime_s - (2. / 9.) * b_prime_s.powi(2))
            * (1. / 3.)
            * b_prime_s
            - c_prime_s * (c_prime_s / K_C)
            - a_prime_s.powi(2) * d_prime_s;
    } else {
        g_prime = a_prime * c_prime - 4. * d_prime - (1. / 3.) * b_prime.powi(2);
        h_prime =
            (a_prime * c_prime + 8. * d_prime - (2. / 9.) * b_prime.powi(2)) * (1. / 3.) * b_prime
                - c_prime.powi(2)
                - a_prime.powi(2) * d_prime;
    }
    // Overflow (or NaN) in the cubic coefficients: give up so the caller can rescale.
    if !(g_prime.is_finite() && h_prime.is_finite()) {
        return None;
    }
    let phi = depressed_cubic_dominant(g_prime, h_prime);
    // Undo the K_C scaling applied above.
    let phi = if rescale { phi * K_C } else { phi };
    let l_1 = a * 0.5;
    let l_3 = (1. / 6.) * b + 0.5 * phi;
    let delt_2 = c - a * l_3;
    // Three candidate (d_2, l_2) pairs; the third mixes d_2 from the first
    // with l_2 from the second.
    let d_2_cand_1 = (2. / 3.) * b - phi - l_1 * l_1;
    let l_2_cand_1 = 0.5 * delt_2 / d_2_cand_1;
    let l_2_cand_2 = 2. * (d - l_3 * l_3) / delt_2;
    let d_2_cand_2 = 0.5 * delt_2 / l_2_cand_2;
    let d_2_cand_3 = d_2_cand_1;
    let l_2_cand_3 = l_2_cand_2;
    let mut d_2_best = 0.0;
    let mut l_2_best = 0.0;
    let mut eps_l_best = 0.0;
    // Keep the candidate with the smallest summed relative error. The first
    // candidate is always accepted, so the zero initial values are never used.
    for (i, (d_2, l_2)) in [
        (d_2_cand_1, l_2_cand_1),
        (d_2_cand_2, l_2_cand_2),
        (d_2_cand_3, l_2_cand_3),
    ]
    .iter()
    .enumerate()
    {
        let eps_0 = eps_rel(d_2 + l_1 * l_1 + 2. * l_3, b);
        let eps_1 = eps_rel(2. * (d_2 * l_2 + l_1 * l_3), c);
        let eps_2 = eps_rel(d_2 * l_2 * l_2 + l_3 * l_3, d);
        let eps_l = eps_0 + eps_1 + eps_2;
        if i == 0 || eps_l < eps_l_best {
            d_2_best = *d_2;
            l_2_best = *l_2;
            eps_l_best = eps_l;
        }
    }
    let d_2 = d_2_best;
    let l_2 = l_2_best;
    // Coefficients of the two quadratic factors x² + alpha x + beta.
    let mut alpha_1;
    let mut beta_1;
    let mut alpha_2;
    let mut beta_2;
    //println!("phi = {}, d_2 = {}", phi, d_2);
    if d_2 < 0.0 {
        let sq = (-d_2).sqrt();
        alpha_1 = l_1 + sq;
        beta_1 = l_3 + sq * l_2;
        alpha_2 = l_1 - sq;
        beta_2 = l_3 - sq * l_2;
        // Recompute the smaller-magnitude beta from the larger one via
        // beta_1 * beta_2 = d.
        if beta_2.abs() < beta_1.abs() {
            beta_2 = d / beta_1;
        } else if beta_2.abs() > beta_1.abs() {
            beta_1 = d / beta_2;
        }
        let cands;
        if alpha_1.abs() != alpha_2.abs() {
            // Keep the larger-magnitude alpha and try three ways of
            // recomputing the other one.
            if alpha_1.abs() < alpha_2.abs() {
                let a1_cand_1 = (c - beta_1 * alpha_2) / beta_2;
                let a1_cand_2 = (b - beta_2 - beta_1) / alpha_2;
                let a1_cand_3 = a - alpha_2;
                // Note: cand 3 is first because it is infallible, simplifying logic
                cands = [
                    (a1_cand_3, alpha_2),
                    (a1_cand_1, alpha_2),
                    (a1_cand_2, alpha_2),
                ];
            } else {
                let a2_cand_1 = (c - alpha_1 * beta_2) / beta_1;
                let a2_cand_2 = (b - beta_2 - beta_1) / alpha_1;
                let a2_cand_3 = a - alpha_1;
                cands = [
                    (alpha_1, a2_cand_3),
                    (alpha_1, a2_cand_1),
                    (alpha_1, a2_cand_2),
                ];
            }
            let mut eps_q_best = 0.0;
            // Pick the finite candidate pair with the smallest error.
            for (i, (a1, a2)) in cands.iter().enumerate() {
                if a1.is_finite() && a2.is_finite() {
                    let eps_q = calc_eps_q(*a1, beta_1, *a2, beta_2);
                    if i == 0 || eps_q < eps_q_best {
                        alpha_1 = *a1;
                        alpha_2 = *a2;
                        eps_q_best = eps_q;
                    }
                }
            }
        }
    } else if d_2 == 0.0 {
        let d_3 = d - l_3 * l_3;
        alpha_1 = l_1;
        beta_1 = l_3 + (-d_3).sqrt();
        alpha_2 = l_1;
        beta_2 = l_3 - (-d_3).sqrt();
        // As above: recompute the smaller-magnitude beta from the larger one.
        if beta_1.abs() > beta_2.abs() {
            beta_2 = d / beta_1;
        } else if beta_2.abs() > beta_1.abs() {
            beta_1 = d / beta_2;
        }
        // TODO: handle case d_2 is very small?
    } else {
        // This case means no real roots; in the most general case we might want
        // to factor into quadratic equations with complex coefficients.
        // (A NaN `d_2` also ends up here, since both comparisons above fail.)
        return None;
    }
    // Newton-Raphson iteration on alpha/beta coeff's.
    // At most 8 steps; stops early when the error is zero, the Jacobian
    // determinant is zero, or a step fails to reduce the error.
    let mut eps_t = calc_eps_t(alpha_1, beta_1, alpha_2, beta_2);
    for _ in 0..8 {
        //println!("a1 {} b1 {} a2 {} b2 {}", alpha_1, beta_1, alpha_2, beta_2);
        //println!("eps_t = {:e}", eps_t);
        if eps_t == 0.0 {
            break;
        }
        // Residuals of the four coefficient equations (constant, x, x², x³).
        let f_0 = beta_1 * beta_2 - d;
        let f_1 = beta_1 * alpha_2 + alpha_1 * beta_2 - c;
        let f_2 = beta_1 + alpha_1 * alpha_2 + beta_2 - b;
        let f_3 = alpha_1 + alpha_2 - a;
        let c_1 = alpha_1 - alpha_2;
        let det_j = beta_1 * beta_1 - beta_1 * (alpha_2 * c_1 + 2. * beta_2)
            + beta_2 * (alpha_1 * c_1 + beta_2);
        if det_j == 0.0 {
            break;
        }
        let inv = det_j.recip();
        let c_2 = beta_2 - beta_1;
        let c_3 = beta_1 * alpha_2 - alpha_1 * beta_2;
        let dz_0 = c_1 * f_0 + c_2 * f_1 + c_3 * f_2 - (beta_1 * c_2 + alpha_1 * c_3) * f_3;
        let dz_1 = (alpha_1 * c_1 + c_2) * f_0
            - beta_1 * c_1 * f_1
            - beta_1 * c_2 * f_2
            - beta_1 * c_3 * f_3;
        let dz_2 = -c_1 * f_0 - c_2 * f_1 - c_3 * f_2 + (alpha_2 * c_3 + beta_2 * c_2) * f_3;
        let dz_3 = -(alpha_2 * c_1 + c_2) * f_0
            + beta_2 * c_1 * f_1
            + beta_2 * c_2 * f_2
            + beta_2 * c_3 * f_3;
        let a1 = alpha_1 - inv * dz_0;
        let b1 = beta_1 - inv * dz_1;
        let a2 = alpha_2 - inv * dz_2;
        let b2 = beta_2 - inv * dz_3;
        let new_eps_t = calc_eps_t(a1, b1, a2, b2);
        // We break if the new eps is equal, paper keeps going
        if new_eps_t < eps_t {
            alpha_1 = a1;
            beta_1 = b1;
            alpha_2 = a2;
            beta_2 = b2;
            eps_t = new_eps_t;
        } else {
            //println!("new_eps_t got worse: {:e}", new_eps_t);
            break;
        }
    }
    Some([(alpha_1, beta_1), (alpha_2, beta_2)].into())
}
561
/// Dominant root of the depressed cubic x^3 + gx + h = 0.
///
/// An initial estimate is computed in closed form and then refined with up
/// to eight Newton-Raphson steps.
///
/// Section 2.2 of Orellana and De Michele.
// Note: some of the techniques in here might be useful to improve the
// cubic solver, and vice versa.
fn depressed_cubic_dominant(g: f64, h: f64) -> f64 {
    const EPS_M: f64 = 2.22045e-16;
    let q = (-1. / 3.) * g;
    let r = 0.5 * h;
    let q_below_r = q.abs() < r.abs();
    // `k` is only formed for very large inputs; `None` selects the direct formulas.
    let k = if q.abs() < 1e102 && r.abs() < 1e154 {
        None
    } else if q_below_r {
        Some(1. - q * (q / r).powi(2))
    } else {
        Some(q.signum() * ((r / q).powi(2) / q - 1.0))
    };
    // Selects the trigonometric formula over the cube-root formula.
    let use_trig = match k {
        Some(k) => k < 0.0,
        None => r * r < q.powi(3),
    };
    let phi_0 = if k.is_some() && r == 0.0 {
        if g > 0.0 {
            0.0
        } else {
            (-g).sqrt()
        }
    } else if use_trig {
        let t = match k {
            Some(_) => r / q / q.sqrt(),
            None => r / q.powi(3).sqrt(),
        };
        -2. * q.sqrt() * (t.abs().acos() * (1. / 3.)).cos().copysign(t)
    } else {
        let cube = match k {
            Some(k) if q_below_r => -r * (1. + k.sqrt()),
            Some(k) => -r - (q.abs().sqrt() * q * k.sqrt()).copysign(r),
            None => -r - (r * r - q.powi(3)).sqrt().copysign(r),
        };
        let a = cube.cbrt();
        let b = if a == 0.0 { 0.0 } else { q / a };
        a + b
    };
    // Refine with Newton-Raphson iteration
    let eval = |x: f64| (x * x + g) * x + h;
    let mut x = phi_0;
    let mut f = eval(x);
    // NOTE(review): the bound uses the signed terms, not their magnitudes;
    // confirm against the paper whether absolute values were intended.
    if f.abs() < EPS_M * x.powi(3).max(g * x).max(h) {
        return x;
    }
    for _ in 0..8 {
        let slope = 3. * x * x + g;
        if slope == 0.0 {
            break;
        }
        let next_x = x - f / slope;
        let next_f = eval(next_x);
        if next_f == 0.0 {
            return next_x;
        }
        // Stop as soon as a step no longer shrinks the residual.
        if next_f.abs() >= f.abs() {
            break;
        }
        x = next_x;
        f = next_f;
    }
    x
}
632
/// Solve an arbitrary function for a zero-crossing.
///
/// This uses the [ITP method], as described in the paper
/// [An Enhancement of the Bisection Method Average Performance Preserving Minmax Optimality].
///
/// The values of `ya` and `yb` are given as arguments rather than
/// computed from `f`, as the values may already be known, or they may
/// be less expensive to compute as special cases.
///
/// It is assumed that `ya < 0.0` and `yb > 0.0`, otherwise unexpected
/// results may occur.
///
/// The `a` and `b` parameters represent the lower and upper bounds of
/// the bracket searched for a solution. The value of `epsilon` must be
/// positive. It may be arbitrarily small relative to `b - a`: the
/// internal scaling by a power of two is done in floating point, so no
/// integer overflow can occur. Note however that the search only stops
/// once the bracket is no wider than `2 * epsilon` or `f` returns
/// exactly zero, so an `epsilon` below the floating-point resolution
/// near the root may prevent termination.
///
/// The ITP method has tuning parameters. This implementation hardwires
/// k2 to 2, both because it avoids an expensive floating point
/// exponentiation, and because this value has been tested to work well
/// with curve fitting problems.
///
/// The `n0` parameter controls the relative impact of the bisection and
/// secant components. When it is 0, the number of iterations is
/// guaranteed to be no more than the number required by bisection (thus,
/// this method is strictly superior to bisection). However, when the
/// function is smooth, a value of 1 gives the secant method more of a
/// chance to engage, so the average number of iterations is likely
/// lower, though there can be one more iteration than bisection in the
/// worst case.
///
/// The `k1` parameter is harder to characterize, and interested users
/// are referred to the paper, as well as encouraged to do empirical
/// testing. To match the paper, a value of `0.2 / (b - a)` is
/// suggested, and this is confirmed to give good results.
///
/// When the function is monotonic, the returned result is guaranteed to
/// be within `epsilon` of the zero crossing. For more detailed analysis,
/// again see the paper.
///
/// [ITP method]: https://en.wikipedia.org/wiki/ITP_Method
/// [An Enhancement of the Bisection Method Average Performance Preserving Minmax Optimality]: https://dl.acm.org/doi/10.1145/3423597
#[allow(clippy::too_many_arguments)]
pub fn solve_itp(
    mut f: impl FnMut(f64) -> f64,
    mut a: f64,
    mut b: f64,
    epsilon: f64,
    n0: usize,
    k1: f64,
    mut ya: f64,
    mut yb: f64,
) -> f64 {
    // Doubling more often than this cannot change the outcome: any nonzero
    // finite `f64` has overflowed to infinity by then, and zero stays zero.
    const MAX_DOUBLINGS: usize = 2100;
    let n1_2 = (((b - a) / epsilon).log2().ceil() - 1.0).max(0.0) as usize;
    let nmax = n0.saturating_add(n1_2);
    // epsilon * 2^nmax. Doubling is exact, so for small `nmax` this equals
    // multiplying by `(1u64 << nmax) as f64`, but it cannot overflow an
    // integer shift when `nmax >= 64`.
    let mut scaled_epsilon = epsilon;
    for _ in 0..nmax.min(MAX_DOUBLINGS) {
        scaled_epsilon *= 2.0;
    }
    while b - a > 2.0 * epsilon {
        // Bisection point and the radius within which the estimate must stay.
        let x1_2 = 0.5 * (a + b);
        let r = scaled_epsilon - 0.5 * (b - a);
        // Regula falsi (secant) estimate.
        let xf = (yb * a - ya * b) / (yb - ya);
        let sigma = x1_2 - xf;
        // This has k2 = 2 hardwired for efficiency.
        let delta = k1 * (b - a).powi(2);
        // Truncation: nudge the secant estimate towards the midpoint.
        let xt = if delta <= (x1_2 - xf).abs() {
            xf + delta.copysign(sigma)
        } else {
            x1_2
        };
        // Projection: clamp to within `r` of the midpoint.
        let xitp = if (xt - x1_2).abs() <= r {
            xt
        } else {
            x1_2 - r.copysign(sigma)
        };
        let yitp = f(xitp);
        if yitp > 0.0 {
            b = xitp;
            yb = yitp;
        } else if yitp < 0.0 {
            a = xitp;
            ya = yitp;
        } else {
            return xitp;
        }
        scaled_epsilon *= 0.5;
    }
    0.5 * (a + b)
}
720
/// A variant ITP solver that allows fallible functions.
///
/// The first `Err` returned by `f` aborts the search and is passed back to
/// the caller.
///
/// Another difference: it returns the bracket that contains the root,
/// which may be important if the function has a discontinuity. When `f`
/// evaluates to exactly zero at some point, that point is returned as
/// both ends of the bracket.
///
/// `epsilon` must be positive. The internal scaling by a power of two is
/// done in floating point, so a very small `epsilon` relative to `b - a`
/// cannot cause integer overflow.
#[allow(clippy::too_many_arguments)]
pub(crate) fn solve_itp_fallible<E>(
    mut f: impl FnMut(f64) -> Result<f64, E>,
    mut a: f64,
    mut b: f64,
    epsilon: f64,
    n0: usize,
    k1: f64,
    mut ya: f64,
    mut yb: f64,
) -> Result<(f64, f64), E> {
    // Doubling more often than this cannot change the outcome: any nonzero
    // finite `f64` has overflowed to infinity by then, and zero stays zero.
    const MAX_DOUBLINGS: usize = 2100;
    let n1_2 = (((b - a) / epsilon).log2().ceil() - 1.0).max(0.0) as usize;
    let nmax = n0.saturating_add(n1_2);
    // epsilon * 2^nmax. Doubling is exact, so for small `nmax` this equals
    // multiplying by `(1u64 << nmax) as f64`, but it cannot overflow an
    // integer shift when `nmax >= 64`.
    let mut scaled_epsilon = epsilon;
    for _ in 0..nmax.min(MAX_DOUBLINGS) {
        scaled_epsilon *= 2.0;
    }
    while b - a > 2.0 * epsilon {
        // Bisection point and the radius within which the estimate must stay.
        let x1_2 = 0.5 * (a + b);
        let r = scaled_epsilon - 0.5 * (b - a);
        // Regula falsi (secant) estimate.
        let xf = (yb * a - ya * b) / (yb - ya);
        let sigma = x1_2 - xf;
        // This has k2 = 2 hardwired for efficiency.
        let delta = k1 * (b - a).powi(2);
        // Truncation: nudge the secant estimate towards the midpoint.
        let xt = if delta <= (x1_2 - xf).abs() {
            xf + delta.copysign(sigma)
        } else {
            x1_2
        };
        // Projection: clamp to within `r` of the midpoint.
        let xitp = if (xt - x1_2).abs() <= r {
            xt
        } else {
            x1_2 - r.copysign(sigma)
        };
        let yitp = f(xitp)?;
        if yitp > 0.0 {
            b = xitp;
            yb = yitp;
        } else if yitp < 0.0 {
            a = xitp;
            ya = yitp;
        } else {
            return Ok((xitp, xitp));
        }
        scaled_epsilon *= 0.5;
    }
    Ok((a, b))
}
770
771// Tables of Legendre-Gauss quadrature coefficients, adapted from:
772// <https://pomax.github.io/bezierinfo/legendre-gauss.html>
773
/// 3-point Gauss-Legendre table; each entry is `(w_i, x_i)` (weight, abscissa).
pub const GAUSS_LEGENDRE_COEFFS_3: &[(f64, f64)] = &[
    (0.8888888888888888, 0.0000000000000000),
    (0.5555555555555556, -0.7745966692414834),
    (0.5555555555555556, 0.7745966692414834),
];
779
/// 4-point Gauss-Legendre table; each entry is `(w_i, x_i)` (weight, abscissa).
pub const GAUSS_LEGENDRE_COEFFS_4: &[(f64, f64)] = &[
    (0.6521451548625461, -0.3399810435848563),
    (0.6521451548625461, 0.3399810435848563),
    (0.3478548451374538, -0.8611363115940526),
    (0.3478548451374538, 0.8611363115940526),
];
786
/// 5-point Gauss-Legendre table; each entry is `(w_i, x_i)` (weight, abscissa).
pub const GAUSS_LEGENDRE_COEFFS_5: &[(f64, f64)] = &[
    (0.5688888888888889, 0.0000000000000000),
    (0.4786286704993665, -0.5384693101056831),
    (0.4786286704993665, 0.5384693101056831),
    (0.2369268850561891, -0.9061798459386640),
    (0.2369268850561891, 0.9061798459386640),
];
794
/// 6-point Gauss-Legendre table; each entry is `(w_i, x_i)` (weight, abscissa).
///
/// Note that the first pair lists the positive abscissa before the negative
/// one, unlike the remaining pairs.
pub const GAUSS_LEGENDRE_COEFFS_6: &[(f64, f64)] = &[
    (0.3607615730481386, 0.6612093864662645),
    (0.3607615730481386, -0.6612093864662645),
    (0.4679139345726910, -0.2386191860831969),
    (0.4679139345726910, 0.2386191860831969),
    (0.1713244923791704, -0.9324695142031521),
    (0.1713244923791704, 0.9324695142031521),
];
803
/// 7-point Gauss-Legendre table; each entry is `(w_i, x_i)` (weight, abscissa).
///
/// Note that the first non-zero pair lists the positive abscissa before the
/// negative one, unlike the remaining pairs.
pub const GAUSS_LEGENDRE_COEFFS_7: &[(f64, f64)] = &[
    (0.4179591836734694, 0.0000000000000000),
    (0.3818300505051189, 0.4058451513773972),
    (0.3818300505051189, -0.4058451513773972),
    (0.2797053914892766, -0.7415311855993945),
    (0.2797053914892766, 0.7415311855993945),
    (0.1294849661688697, -0.9491079123427585),
    (0.1294849661688697, 0.9491079123427585),
];
813
/// 8-point Gauss-Legendre table; each entry is `(w_i, x_i)` (weight, abscissa).
pub const GAUSS_LEGENDRE_COEFFS_8: &[(f64, f64)] = &[
    (0.3626837833783620, -0.1834346424956498),
    (0.3626837833783620, 0.1834346424956498),
    (0.3137066458778873, -0.5255324099163290),
    (0.3137066458778873, 0.5255324099163290),
    (0.2223810344533745, -0.7966664774136267),
    (0.2223810344533745, 0.7966664774136267),
    (0.1012285362903763, -0.9602898564975363),
    (0.1012285362903763, 0.9602898564975363),
];
824
/// The four entries of [`GAUSS_LEGENDRE_COEFFS_8`] with positive `x_i`.
pub const GAUSS_LEGENDRE_COEFFS_8_HALF: &[(f64, f64)] = &[
    (0.3626837833783620, 0.1834346424956498),
    (0.3137066458778873, 0.5255324099163290),
    (0.2223810344533745, 0.7966664774136267),
    (0.1012285362903763, 0.9602898564975363),
];
831
/// 9-point Gauss-Legendre table; each entry is `(w_i, x_i)` (weight, abscissa).
///
/// The entries are not sorted by abscissa magnitude.
pub const GAUSS_LEGENDRE_COEFFS_9: &[(f64, f64)] = &[
    (0.3302393550012598, 0.0000000000000000),
    (0.1806481606948574, -0.8360311073266358),
    (0.1806481606948574, 0.8360311073266358),
    (0.0812743883615744, -0.9681602395076261),
    (0.0812743883615744, 0.9681602395076261),
    (0.3123470770400029, -0.3242534234038089),
    (0.3123470770400029, 0.3242534234038089),
    (0.2606106964029354, -0.6133714327005904),
    (0.2606106964029354, 0.6133714327005904),
];
843
/// 11-point Gauss-Legendre table; each entry is `(w_i, x_i)` (weight, abscissa).
pub const GAUSS_LEGENDRE_COEFFS_11: &[(f64, f64)] = &[
    (0.2729250867779006, 0.0000000000000000),
    (0.2628045445102467, -0.2695431559523450),
    (0.2628045445102467, 0.2695431559523450),
    (0.2331937645919905, -0.5190961292068118),
    (0.2331937645919905, 0.5190961292068118),
    (0.1862902109277343, -0.7301520055740494),
    (0.1862902109277343, 0.7301520055740494),
    (0.1255803694649046, -0.8870625997680953),
    (0.1255803694649046, 0.8870625997680953),
    (0.0556685671161737, -0.9782286581460570),
    (0.0556685671161737, 0.9782286581460570),
];
857
/// Coefficients for 16-point Gauss-Legendre quadrature.
///
/// Each entry is a `(w_i, x_i)` pair: the weight followed by the node
/// (abscissa) on the reference interval `[-1, 1]`. Entries come in mirrored
/// pairs (`-x_i` then `+x_i`, sharing one weight), listed in order of
/// increasing `|x_i|`.
pub const GAUSS_LEGENDRE_COEFFS_16: &[(f64, f64)] = &[
    (0.1894506104550685, -0.0950125098376374),
    (0.1894506104550685, 0.0950125098376374),
    (0.1826034150449236, -0.2816035507792589),
    (0.1826034150449236, 0.2816035507792589),
    (0.1691565193950025, -0.4580167776572274),
    (0.1691565193950025, 0.4580167776572274),
    (0.1495959888165767, -0.6178762444026438),
    (0.1495959888165767, 0.6178762444026438),
    (0.1246289712555339, -0.7554044083550030),
    (0.1246289712555339, 0.7554044083550030),
    (0.0951585116824928, -0.8656312023878318),
    (0.0951585116824928, 0.8656312023878318),
    (0.0622535239386479, -0.9445750230732326),
    (0.0622535239386479, 0.9445750230732326),
    (0.0271524594117541, -0.9894009349916499),
    (0.0271524594117541, 0.9894009349916499),
];
876
/// The positive-node half of the 16-point Gauss-Legendre table.
///
/// Just the positive `x_i` values: each entry is a `(w_i, x_i)` pair with
/// `x_i > 0`. The weights are the same values as in the full 16-point table
/// (they are not doubled), so a consumer has to account for the mirrored
/// `-x_i` nodes itself.
pub const GAUSS_LEGENDRE_COEFFS_16_HALF: &[(f64, f64)] = &[
    (0.1894506104550685, 0.0950125098376374),
    (0.1826034150449236, 0.2816035507792589),
    (0.1691565193950025, 0.4580167776572274),
    (0.1495959888165767, 0.6178762444026438),
    (0.1246289712555339, 0.7554044083550030),
    (0.0951585116824928, 0.8656312023878318),
    (0.0622535239386479, 0.9445750230732326),
    (0.0271524594117541, 0.9894009349916499),
];
888
/// Coefficients for 24-point Gauss-Legendre quadrature.
///
/// Each entry is a `(w_i, x_i)` pair: the weight followed by the node
/// (abscissa) on the reference interval `[-1, 1]`. Entries come in mirrored
/// pairs (`-x_i` then `+x_i`, sharing one weight), listed in order of
/// increasing `|x_i|`.
pub const GAUSS_LEGENDRE_COEFFS_24: &[(f64, f64)] = &[
    (0.1279381953467522, -0.0640568928626056),
    (0.1279381953467522, 0.0640568928626056),
    (0.1258374563468283, -0.1911188674736163),
    (0.1258374563468283, 0.1911188674736163),
    (0.1216704729278034, -0.3150426796961634),
    (0.1216704729278034, 0.3150426796961634),
    (0.1155056680537256, -0.4337935076260451),
    (0.1155056680537256, 0.4337935076260451),
    (0.1074442701159656, -0.5454214713888396),
    (0.1074442701159656, 0.5454214713888396),
    (0.0976186521041139, -0.6480936519369755),
    (0.0976186521041139, 0.6480936519369755),
    (0.0861901615319533, -0.7401241915785544),
    (0.0861901615319533, 0.7401241915785544),
    (0.0733464814110803, -0.8200019859739029),
    (0.0733464814110803, 0.8200019859739029),
    (0.0592985849154368, -0.8864155270044011),
    (0.0592985849154368, 0.8864155270044011),
    (0.0442774388174198, -0.9382745520027328),
    (0.0442774388174198, 0.9382745520027328),
    (0.0285313886289337, -0.9747285559713095),
    (0.0285313886289337, 0.9747285559713095),
    (0.0123412297999872, -0.9951872199970213),
    (0.0123412297999872, 0.9951872199970213),
];
915
/// The positive-node half of the 24-point Gauss-Legendre table.
///
/// Each entry is a `(w_i, x_i)` pair with `x_i > 0`. The weights are the same
/// values as in the full 24-point table (they are not doubled), so a consumer
/// has to account for the mirrored `-x_i` nodes itself.
pub const GAUSS_LEGENDRE_COEFFS_24_HALF: &[(f64, f64)] = &[
    (0.1279381953467522, 0.0640568928626056),
    (0.1258374563468283, 0.1911188674736163),
    (0.1216704729278034, 0.3150426796961634),
    (0.1155056680537256, 0.4337935076260451),
    (0.1074442701159656, 0.5454214713888396),
    (0.0976186521041139, 0.6480936519369755),
    (0.0861901615319533, 0.7401241915785544),
    (0.0733464814110803, 0.8200019859739029),
    (0.0592985849154368, 0.8864155270044011),
    (0.0442774388174198, 0.9382745520027328),
    (0.0285313886289337, 0.9747285559713095),
    (0.0123412297999872, 0.9951872199970213),
];
930
/// Coefficients for 32-point Gauss-Legendre quadrature.
///
/// Each entry is a `(w_i, x_i)` pair: the weight followed by the node
/// (abscissa) on the reference interval `[-1, 1]`. Entries come in mirrored
/// pairs (`-x_i` then `+x_i`, sharing one weight), listed in order of
/// increasing `|x_i|`.
pub const GAUSS_LEGENDRE_COEFFS_32: &[(f64, f64)] = &[
    (0.0965400885147278, -0.0483076656877383),
    (0.0965400885147278, 0.0483076656877383),
    (0.0956387200792749, -0.1444719615827965),
    (0.0956387200792749, 0.1444719615827965),
    (0.0938443990808046, -0.2392873622521371),
    (0.0938443990808046, 0.2392873622521371),
    (0.0911738786957639, -0.3318686022821277),
    (0.0911738786957639, 0.3318686022821277),
    (0.0876520930044038, -0.4213512761306353),
    (0.0876520930044038, 0.4213512761306353),
    (0.0833119242269467, -0.5068999089322294),
    (0.0833119242269467, 0.5068999089322294),
    (0.0781938957870703, -0.5877157572407623),
    (0.0781938957870703, 0.5877157572407623),
    (0.0723457941088485, -0.6630442669302152),
    (0.0723457941088485, 0.6630442669302152),
    (0.0658222227763618, -0.7321821187402897),
    (0.0658222227763618, 0.7321821187402897),
    (0.0586840934785355, -0.7944837959679424),
    (0.0586840934785355, 0.7944837959679424),
    (0.0509980592623762, -0.8493676137325700),
    (0.0509980592623762, 0.8493676137325700),
    (0.0428358980222267, -0.8963211557660521),
    (0.0428358980222267, 0.8963211557660521),
    (0.0342738629130214, -0.9349060759377397),
    (0.0342738629130214, 0.9349060759377397),
    (0.0253920653092621, -0.9647622555875064),
    (0.0253920653092621, 0.9647622555875064),
    (0.0162743947309057, -0.9856115115452684),
    (0.0162743947309057, 0.9856115115452684),
    (0.0070186100094701, -0.9972638618494816),
    (0.0070186100094701, 0.9972638618494816),
];
965
/// The positive-node half of the 32-point Gauss-Legendre table.
///
/// Each entry is a `(w_i, x_i)` pair with `x_i > 0`. The weights are the same
/// values as in the full 32-point table (they are not doubled), so a consumer
/// has to account for the mirrored `-x_i` nodes itself.
pub const GAUSS_LEGENDRE_COEFFS_32_HALF: &[(f64, f64)] = &[
    (0.0965400885147278, 0.0483076656877383),
    (0.0956387200792749, 0.1444719615827965),
    (0.0938443990808046, 0.2392873622521371),
    (0.0911738786957639, 0.3318686022821277),
    (0.0876520930044038, 0.4213512761306353),
    (0.0833119242269467, 0.5068999089322294),
    (0.0781938957870703, 0.5877157572407623),
    (0.0723457941088485, 0.6630442669302152),
    (0.0658222227763618, 0.7321821187402897),
    (0.0586840934785355, 0.7944837959679424),
    (0.0509980592623762, 0.8493676137325700),
    (0.0428358980222267, 0.8963211557660521),
    (0.0342738629130214, 0.9349060759377397),
    (0.0253920653092621, 0.9647622555875064),
    (0.0162743947309057, 0.9856115115452684),
    (0.0070186100094701, 0.9972638618494816),
];
984
#[cfg(test)]
mod tests {
    use crate::common::*;
    use arrayvec::ArrayVec;

    /// Asserts that `roots`, once sorted ascending, matches `expected`
    /// element-wise to within an absolute tolerance of `1e-12`.
    ///
    /// `expected` must be given in ascending order, since only `roots` is
    /// sorted before the comparison.
    fn verify<const N: usize>(mut roots: ArrayVec<f64, N>, expected: &[f64]) {
        assert_eq!(
            expected.len(),
            roots.len(),
            "unexpected number of roots: {:?}",
            roots
        );
        let epsilon = 1e-12;
        // The solvers are not expected to return NaN, so `partial_cmp` may be unwrapped.
        roots.sort_by(|a, b| a.partial_cmp(b).unwrap());
        for (i, (root, want)) in roots.iter().zip(expected).enumerate() {
            assert!(
                (root - want).abs() < epsilon,
                "root {}: actual {:e}, expected {:e}",
                i,
                root,
                want
            );
        }
    }

    #[test]
    fn test_solve_cubic() {
        verify(solve_cubic(-5.0, 0.0, 0.0, 1.0), &[5.0f64.cbrt()]);
        verify(solve_cubic(-5.0, -1.0, 0.0, 1.0), &[1.90416085913492]);
        verify(solve_cubic(0.0, -1.0, 0.0, 1.0), &[-1.0, 0.0, 1.0]);
        verify(solve_cubic(-2.0, -3.0, 0.0, 1.0), &[-1.0, 2.0]);
        verify(solve_cubic(2.0, -3.0, 0.0, 1.0), &[-2.0, 1.0]);
        // Perturbing the constant term by -1e-12 splits the near-double root
        // at -1 into two distinct real roots ...
        verify(
            solve_cubic(2.0 - 1e-12, 5.0, 4.0, 1.0),
            &[
                -1.9999999999989995,
                -1.0000010000848456,
                -0.9999989999161546,
            ],
        );
        // ... while perturbing it by +1e-12 leaves a single real root.
        verify(solve_cubic(2.0 + 1e-12, 5.0, 4.0, 1.0), &[-2.0]);
    }

    #[test]
    fn test_solve_quadratic() {
        verify(
            solve_quadratic(-5.0, 0.0, 1.0),
            &[-(5.0f64.sqrt()), 5.0f64.sqrt()],
        );
        // No real roots.
        verify(solve_quadratic(5.0, 0.0, 1.0), &[]);
        // Leading coefficient is zero: degenerates to a linear equation.
        verify(solve_quadratic(5.0, 1.0, 0.0), &[-5.0]);
        // Double root is reported once.
        verify(solve_quadratic(1.0, 2.0, 1.0), &[-1.0]);
    }

    #[test]
    fn test_solve_quartic() {
        // These test cases are taken from Orellana and De Michele paper (Table 1).

        /// Solves the monic quartic with the given coefficients and checks the
        /// sorted result against `roots` (which must be ascending) using the
        /// relative error bound `rel_err`.
        fn test_with_roots(coeffs: [f64; 4], roots: &[f64], rel_err: f64) {
            // Note: in paper, coefficients are in decreasing order.
            let mut found = solve_quartic(coeffs[3], coeffs[2], coeffs[1], coeffs[0], 1.0);
            found.sort_by(f64::total_cmp);
            assert_eq!(found.len(), roots.len());
            for (actual, expected) in found.iter().zip(roots) {
                assert!(
                    (actual - expected).abs() < rel_err * expected.abs(),
                    "actual {:e}, expected {:e}, err {:e}",
                    actual,
                    expected,
                    actual - expected
                );
            }
        }

        /// Builds the monic quartic `(x - x1)(x - x2)(x - x3)(x - x4)` via
        /// Vieta's formulas and checks that the solver reports `roots`.
        ///
        /// `roots` is passed separately from `x1..x4` so that repeated roots
        /// can be listed with the multiplicity the solver actually reports.
        fn test_vieta_roots(x1: f64, x2: f64, x3: f64, x4: f64, roots: &[f64], rel_err: f64) {
            // a = -(sum of roots)
            let a = -(x1 + x2 + x3 + x4);
            // b = sum of all pairwise products
            let b = x1 * (x2 + x3) + x2 * (x3 + x4) + x4 * (x1 + x3);
            // c = -(sum of all triple products)
            let c = -x1 * x2 * (x3 + x4) - x3 * x4 * (x1 + x2);
            // d = product of roots
            let d = x1 * x2 * x3 * x4;
            test_with_roots([a, b, c, d], roots, rel_err);
        }

        /// Like `test_vieta_roots`, expecting all four roots to be reported.
        fn test_vieta(x1: f64, x2: f64, x3: f64, x4: f64, rel_err: f64) {
            test_vieta_roots(x1, x2, x3, x4, &[x1, x2, x3, x4], rel_err);
        }

        // case 1
        test_vieta(1., 1e3, 1e6, 1e9, 1e-16);
        // case 2
        test_vieta(2., 2.001, 2.002, 2.003, 1e-6);
        // case 3
        test_vieta(1e47, 1e49, 1e50, 1e53, 2e-16);
        // case 4
        test_vieta(-1., 1., 2., 1e14, 1e-16);
        // case 5
        test_vieta(-2e7, -1., 1., 1e7, 1e-16);
        // case 6
        test_with_roots(
            [-9000002.0, -9999981999998.0, 19999982e6, -2e13],
            &[-1e6, 1e7],
            1e-16,
        );
        // case 7
        test_with_roots(
            [2000011.0, 1010022000028.0, 11110056e6, 2828e10],
            &[-7., -4.],
            1e-16,
        );
        // case 8
        test_with_roots(
            [-100002011.0, 201101022001.0, -102200111000011.0, 11000011e8],
            &[11., 1e8],
            1e-16,
        );
        // cases 9-13 have no real roots
        // case 14
        test_vieta_roots(1000., 1000., 1000., 1000., &[1000., 1000.], 1e-16);
        // case 15
        test_vieta_roots(1e-15, 1000., 1000., 1000., &[1e-15, 1000., 1000.], 1e-15);
        // case 16 no real roots
        // case 17
        test_vieta(10000., 10001., 10010., 10100., 1e-6);
        // (case 18 is not exercised here)
        // case 19
        test_vieta_roots(1., 1e30, 1e30, 1e44, &[1., 1e30, 1e44], 1e-16);
        // case 20
        // Previously annotated "FAILS, error too big": the tolerance used here
        // (1e-7) is far looser than for the neighbouring cases, presumably
        // because the solver's error on this input is too large for a tight
        // bound. The assertion below is active and expected to pass.
        test_vieta(1., 1e7, 1e7, 1e14, 1e-7);
        // case 21 doesn't pick up double root
        // case 22
        test_vieta(1., 10., 1e152, 1e154, 3e-16);
        // case 23
        test_with_roots(
            [1., 1., 3. / 8., 1e-3],
            &[-0.497314148060048, -0.00268585193995149],
            2e-15,
        );
        // case 24
        const S: f64 = 1e30;
        test_with_roots(
            [-(1. + 1. / S), 1. / S - S * S, S * S + S, -S],
            &[-S, 1e-30, 1., S],
            2e-16,
        );
    }

    #[test]
    fn test_solve_itp() {
        let f = |x: f64| x.powi(3) - x - 2.0;
        let x = solve_itp(f, 1., 2., 1e-12, 0, 0.2, f(1.), f(2.));
        assert!(f(x).abs() < 6e-12);
    }

    /// Uses `solve_itp` to find the parameter at which a cubic Bézier reaches
    /// a target arc length, and checks that the result actually hits it.
    #[test]
    fn test_inv_arclen() {
        use crate::{ParamCurve, ParamCurveArclen};
        let c = crate::CubicBez::new(
            (0.0, 0.0),
            (100.0 / 3.0, 0.0),
            (200.0 / 3.0, 100.0 / 3.0),
            (100.0, 100.0),
        );
        let target = 100.0;
        let t = solve_itp(
            |t| c.subsegment(0.0..t).arclen(1e-9) - target,
            0.,
            1.,
            1e-6,
            1,
            0.2,
            -target,
            c.arclen(1e-9) - target,
        );
        assert!((0.0..=1.0).contains(&t), "t = {:e} is outside [0, 1]", t);
        // The solver is asked for a parameter accuracy of 1e-6. The curve's
        // speed is bounded by the largest derivative control vector,
        // |3 * (p3 - p2)| ~= 224, so the arc length error should stay well
        // below 1e-3.
        let residual = c.subsegment(0.0..t).arclen(1e-9) - target;
        assert!(
            residual.abs() < 1e-3,
            "t = {:e}, arclen residual {:e}",
            t,
            residual
        );
    }
}
1144        );
1145    }
1146}