bigdecimal/
impl_cmp.rs

1//! Implementation of comparison operations
2//!
//! Comparisons between decimals and decimal refs
//! are not directly supported, as supporting them would
//! cost some type-inference ability in exchange for
//! saving a single '&' character.
7//!
8//! &BigDecimal and BigDecimalRef are comparable.
9//!
10
11use crate::*;
12
13use stdlib::cmp::Ordering;
14use stdlib::iter;
15
16impl PartialEq for BigDecimal {
17    fn eq(&self, rhs: &BigDecimal) -> bool {
18        self.to_ref() == rhs.to_ref()
19    }
20}
21
22impl<'rhs, T> PartialEq<T> for BigDecimalRef<'_>
23where
24    T: Into<BigDecimalRef<'rhs>> + Copy,
25{
26    fn eq(&self, rhs: &T) -> bool {
27        let rhs: BigDecimalRef<'rhs> = (*rhs).into();
28        check_equality_bigdecimal_ref(*self, rhs)
29    }
30}
31
32fn check_equality_bigdecimal_ref(lhs: BigDecimalRef, rhs: BigDecimalRef) -> bool {
33    match (lhs.sign(), rhs.sign()) {
34        // both zero
35        (Sign::NoSign, Sign::NoSign) => return true,
36        // signs are different
37        (a, b) if a != b => return false,
38        // signs are same, do nothing
39        _ => {}
40    }
41
42    let unscaled_int;
43    let scaled_int;
44    let trailing_zero_count;
45    match arithmetic::checked_diff(lhs.scale, rhs.scale) {
46        (Ordering::Equal, _) => {
47            return lhs.digits == rhs.digits;
48        }
49        (Ordering::Greater, Some(scale_diff)) => {
50            unscaled_int = lhs.digits;
51            scaled_int = rhs.digits;
52            trailing_zero_count = scale_diff;
53        }
54        (Ordering::Less, Some(scale_diff)) => {
55            unscaled_int = rhs.digits;
56            scaled_int = lhs.digits;
57            trailing_zero_count = scale_diff;
58        }
59        _ => {
60            // all other cases imply overflow in difference of scale,
61            // numbers must not be equal
62            return false;
63        }
64    }
65
66    debug_assert_ne!(trailing_zero_count, 0);
67
68    // test if unscaled_int is guaranteed to be less than
69    // scaled_int*10^trailing_zero_count based on highest bit
70    if highest_bit_lessthan_scaled(unscaled_int, scaled_int, trailing_zero_count) {
71        return false;
72    }
73
74    // try compare without allocating
75    if trailing_zero_count < 20 {
76        let pow = ten_to_the_u64(trailing_zero_count as u8);
77
78        let mut a_digits = unscaled_int.iter_u32_digits();
79        let mut b_digits = scaled_int.iter_u32_digits();
80
81        let mut carry = 0;
82        loop {
83            match (a_digits.next(), b_digits.next()) {
84                (Some(next_a), Some(next_b)) => {
85                    let wide_b = match (next_b as u64).checked_mul(pow) {
86                        Some(tmp) => tmp + carry,
87                        None => break,
88                    };
89
90                    let true_b = wide_b as u32;
91
92                    if next_a != true_b {
93                        return false;
94                    }
95
96                    carry = wide_b >> 32;
97                }
98                (None, Some(_)) => {
99                    return false;
100                }
101                (Some(a_digit), None) => {
102                    if a_digit != (carry as u32) {
103                        return false;
104                    }
105                    carry = 0;
106                }
107                (None, None) => {
108                    return carry == 0;
109                }
110            }
111        }
112
113        // we broke out of loop due to overflow - compare via allocation
114        let scaled_int = scaled_int * pow;
115        return &scaled_int == unscaled_int;
116    }
117
118    let trailing_zero_count = trailing_zero_count.to_usize().unwrap();
119    let unscaled_digits = unscaled_int.to_radix_le(10);
120
121    if trailing_zero_count > unscaled_digits.len() {
122        return false;
123    }
124
125    // split into digits below the other value, and digits overlapping
126    let (low_digits, overlap_digits) = unscaled_digits.split_at(trailing_zero_count);
127
128    // if any of the low digits are zero, they are not equal
129    if low_digits.iter().any(|&d| d != 0) {
130        return false;
131    }
132
133    let scaled_digits = scaled_int.to_radix_le(10);
134
135    // different lengths with trailing zeros
136    if overlap_digits.len() != scaled_digits.len() {
137        return false;
138    }
139
140    // return true if all digits are the same
141    overlap_digits.iter().zip(scaled_digits.iter()).all(|(digit_a, digit_b)| digit_a == digit_b)
142}
143
144
145impl PartialOrd for BigDecimal {
146    #[inline]
147    fn partial_cmp(&self, other: &BigDecimal) -> Option<Ordering> {
148        Some(self.cmp(other))
149    }
150}
151
152impl PartialOrd for BigDecimalRef<'_> {
153    fn partial_cmp(&self, other: &BigDecimalRef<'_>) -> Option<Ordering> {
154        Some(self.cmp(other))
155    }
156}
157
158
159impl Ord for BigDecimal {
160    #[inline]
161    fn cmp(&self, other: &BigDecimal) -> Ordering {
162        self.to_ref().cmp(&other.to_ref())
163    }
164}
165
166impl Ord for BigDecimalRef<'_> {
167    /// Complete ordering implementation for BigDecimal
168    ///
169    /// # Example
170    ///
171    /// ```
172    /// use std::str::FromStr;
173    ///
174    /// let a = bigdecimal::BigDecimal::from_str("-1").unwrap();
175    /// let b = bigdecimal::BigDecimal::from_str("1").unwrap();
176    /// assert!(a < b);
177    /// assert!(b > a);
178    /// let c = bigdecimal::BigDecimal::from_str("1").unwrap();
179    /// assert!(b >= c);
180    /// assert!(c >= b);
181    /// let d = bigdecimal::BigDecimal::from_str("10.0").unwrap();
182    /// assert!(d > c);
183    /// let e = bigdecimal::BigDecimal::from_str(".5").unwrap();
184    /// assert!(e < c);
185    /// ```
186    #[inline]
187    fn cmp(&self, other: &BigDecimalRef) -> Ordering {
188        use Ordering::*;
189
190        let scmp = self.sign().cmp(&other.sign());
191        if scmp != Ordering::Equal {
192            return scmp;
193        }
194
195        if self.sign() == Sign::NoSign {
196            return Ordering::Equal;
197        }
198
199        let result = match arithmetic::checked_diff(self.scale, other.scale) {
200            (Greater, Some(scale_diff)) | (Equal, Some(scale_diff)) => {
201                compare_scaled_biguints(self.digits, other.digits, scale_diff)
202            }
203            (Less, Some(scale_diff)) => {
204                compare_scaled_biguints(other.digits, self.digits, scale_diff).reverse()
205            }
206            (res, None) => {
207                // The difference in scale does not fit in a u64,
208                // we can safely assume the value of digits do not matter
209                // (unless we have a 2^64 (i.e. ~16 exabyte) long number
210
211                // larger scale means smaller number, reverse this ordering
212                res.reverse()
213            }
214        };
215
216        if other.sign == Sign::Minus {
217            result.reverse()
218        } else {
219            result
220        }
221    }
222}
223
224
225/// compare scaled uints: a <=> b * 10^{scale_diff}
226///
227fn compare_scaled_biguints(a: &BigUint, b: &BigUint, scale_diff: u64) -> Ordering {
228    use Ordering::*;
229
230    if scale_diff == 0 {
231        return a.cmp(b);
232    }
233
234    // check if highest bit of a is less than b * 10^scale_diff
235    if highest_bit_lessthan_scaled(a, b, scale_diff) {
236        return Ordering::Less;
237    }
238
239    // if biguints fit it u64 or u128, compare using those (avoiding allocations)
240    if let Some(result) = compare_scalar_biguints(a, b, scale_diff) {
241        return result;
242    }
243
244    let a_digit_count = count_decimal_digits_uint(a);
245    let b_digit_count = count_decimal_digits_uint(b);
246
247    let digit_count_cmp = a_digit_count.cmp(&(b_digit_count + scale_diff));
248    if digit_count_cmp != Equal {
249        return digit_count_cmp;
250    }
251
252    let a_digits = a.to_radix_le(10);
253    let b_digits = b.to_radix_le(10);
254
255    debug_assert_eq!(a_digits.len(), a_digit_count as usize);
256    debug_assert_eq!(b_digits.len(), b_digit_count as usize);
257
258    let mut a_it = a_digits.iter().rev();
259    let mut b_it = b_digits.iter().rev();
260
261    loop {
262        match (a_it.next(), b_it.next()) {
263            (Some(ai), Some(bi)) => {
264                match ai.cmp(bi) {
265                    Equal => continue,
266                    result => return result,
267                }
268            }
269            (Some(&ai), None) => {
270                if ai == 0 && a_it.all(Zero::is_zero) {
271                    return Equal;
272                } else {
273                    return Greater;
274                }
275            }
276            (None, Some(&bi)) => {
277                if bi == 0 && b_it.all(Zero::is_zero) {
278                    return Equal;
279                } else {
280                    return Less;
281                }
282            }
283            (None, None) => {
284                return Equal;
285            }
286        }
287    }
288}
289
290/// Try fitting biguints into primitive integers, using those for ordering if possible
291fn compare_scalar_biguints(a: &BigUint, b: &BigUint, scale_diff: u64) -> Option<Ordering> {
292    let scale_diff = scale_diff.to_usize()?;
293
294    // try u64, then u128
295    compare_scaled_uints::<u64>(a, b, scale_diff)
296    .or_else(|| compare_scaled_uints::<u128>(a, b, scale_diff))
297}
298
299/// Implementation comparing biguints cast to generic type
300fn compare_scaled_uints<'a, T>(
301    a: &'a BigUint,
302    b: &'a BigUint,
303    scale_diff: usize,
304) -> Option<Ordering>
305where
306    T: num_traits::PrimInt + TryFrom<&'a BigUint>,
307{
308    let ten = T::from(10).unwrap();
309
310    let a = T::try_from(a).ok();
311    let b = T::try_from(b).ok().and_then(
312                |b| num_traits::checked_pow(ten, scale_diff).and_then(
313                    |p| b.checked_mul(&p)));
314
315    match (a, b) {
316        (Some(a), Some(scaled_b)) => Some(a.cmp(&scaled_b)),
317        // if scaled_b doesn't fit in size T, while 'a' does, then a is certainly less
318        (Some(_), None) => Some(Ordering::Less),
319        // if a doesn't fit in size T, while 'scaled_b' does, then a is certainly greater
320        (None, Some(_)) => Some(Ordering::Greater),
321        // neither fits, cannot determine relative size
322        (None, None) => None,
323    }
324}
325
326/// Return highest_bit(a) < highest_bit(b * 10^{scale})
327///
328/// Used for optimization when comparing scaled integers
329///
330/// ```math
331/// a < b * 10^{scale}
332/// log(a) < log(b) + scale * log(10)
333/// ```
334///
335fn highest_bit_lessthan_scaled(a: &BigUint, b: &BigUint, scale: u64) -> bool {
336    let a_bits = a.bits();
337    let b_bits = b.bits();
338    if a_bits < b_bits {
339        return true;
340    }
341    let log_scale = LOG2_10 * scale as f64;
342    match b_bits.checked_add(log_scale as u64) {
343        Some(scaled_b_bit) => a_bits < scaled_b_bit,
344        None => true, // overflowing u64 means we are definitely bigger
345    }
346}
347
348macro_rules! impl_prim_cmp {
349    ($t:ty) => {
350        impl PartialOrd<$t> for BigDecimal {
351            fn partial_cmp(&self, other: &$t) -> Option<Ordering> {
352                self.to_ref().partial_cmp(other)
353            }
354        }
355
356        impl PartialEq<$t> for BigDecimal {
357            fn eq(&self, rhs: &$t) -> bool {
358                self.to_ref().eq(rhs)
359            }
360        }
361
362        impl PartialOrd<$t> for &BigDecimal {
363            fn partial_cmp(&self, other: &$t) -> Option<Ordering> {
364                self.to_ref().partial_cmp(other)
365            }
366        }
367
368        impl PartialOrd<$t> for BigDecimalRef<'_>
369        {
370            fn partial_cmp(&self, other: &$t) -> Option<Ordering> {
371                let rhs = BigDecimal::from(other);
372                self.partial_cmp(&rhs.to_ref())
373            }
374        }
375
376        impl PartialEq<$t> for &BigDecimal {
377            fn eq(&self, rhs: &$t) -> bool {
378                self.to_ref().eq(rhs)
379            }
380        }
381
382        impl PartialEq<$t> for BigDecimalRef<'_>
383        {
384            fn eq(&self, rhs: &$t) -> bool {
385                let rhs = BigDecimal::from(rhs);
386                check_equality_bigdecimal_ref(*self, rhs.to_ref())
387            }
388        }
389    };
390}
391
392impl_prim_cmp!(u8);
393impl_prim_cmp!(u16);
394impl_prim_cmp!(u32);
395impl_prim_cmp!(u64);
396impl_prim_cmp!(u128);
397impl_prim_cmp!(i8);
398impl_prim_cmp!(i16);
399impl_prim_cmp!(i32);
400impl_prim_cmp!(i64);
401impl_prim_cmp!(i128);
402
403
#[cfg(test)]
mod test {
    use super::*;

    mod compare_scaled_biguints {
        use super::*;

        macro_rules! impl_test {
            ($name:ident: $a:literal > $b:literal e $e:literal) => {
                impl_test!($name: $a Greater $b e $e);
            };
            ($name:ident: $a:literal < $b:literal e $e:literal) => {
                impl_test!($name: $a Less $b e $e);
            };
            ($name:ident: $a:literal = $b:literal e $e:literal) => {
                impl_test!($name: $a Equal $b e $e);
            };
            ($name:ident: $a:literal $op:ident $b:literal e $e:literal) => {
                #[test]
                fn $name() {
                    let a: BigUint = $a.parse().unwrap();
                    let b: BigUint = $b.parse().unwrap();

                    let result = compare_scaled_biguints(&a, &b, $e);
                    assert_eq!(result, Ordering::$op);
                }
            };
        }

        impl_test!(case_500_51e1: "500" < "51" e 1);
        impl_test!(case_500_44e1: "500" > "44" e 1);
        impl_test!(case_5000_50e2: "5000" = "50" e 2);
        impl_test!(case_1234e9_12345e9: "1234000000000" < "12345" e 9);
        impl_test!(case_1116xx459_759xx717e2: "1116386634271380982470843247639640260491505327092723527088459" < "759522625769651746138617259189939751893902453291243506584717" e 2);
    }

    /// Test that large-magnitude exponentials will not crash
    #[test]
    fn test_cmp_on_exp_boundaries() {
        let a = BigDecimal::new(1.into(), i64::MAX);
        let z = BigDecimal::new(1.into(), i64::MIN);
        assert_ne!(a, z);
        assert_ne!(z, a);

        assert!(a < z);

        assert_eq!(a, a);
        assert_eq!(z, z);
    }

    mod ord {
        use super::*;

        macro_rules! impl_test {
            ($name:ident: $a:literal < $b:literal) => {
                #[test]
                fn $name() {
                    let a: BigDecimal = $a.parse().unwrap();
                    let b: BigDecimal = $b.parse().unwrap();

                    assert!(&a < &b);
                    assert!(&b > &a);
                    assert_ne!(a, b);
                }
            };
        }

        impl_test!(case_diff_signs: "-1" < "1");
        impl_test!(case_n1_0: "-1" < "0");
        impl_test!(case_0_1: "0" < "1");
        impl_test!(case_1d2345_1d2346: "1.2345" < "1.2346");
        impl_test!(case_compare_extreme: "1e-9223372036854775807" < "1");
        impl_test!(case_compare_extremes: "1e-9223372036854775807" < "1e9223372036854775807");
        impl_test!(case_small_difference: "472697816888807260.1604" < "472697816888807260.16040000000000000000001");
        impl_test!(case_very_small_diff: "-1.0000000000000000000000000000000000000000000000000001" < "-1");
        impl_test!(case_1_1d0x21_1: "1" < "1.0000000000000000000001");

        impl_test!(case_1_2p128: "1" < "340282366920938463463374607431768211455");
        impl_test!(case_1_1e39: "1000000000000000000000000000000000000000" < "1e41");

        impl_test!(case_1d414xxx573: "1.414213562373095048801688724209698078569671875376948073176679730000000000000000000000000000000000000" < "1.41421356237309504880168872420969807856967187537694807317667974000000000");
        impl_test!(case_11d414xxx573: "1.414213562373095048801688724209698078569671875376948073176679730000000000000000000000000000000000000" < "11.41421356237309504880168872420969807856967187537694807317667974000000000");
    }

    mod eq {
        use super::*;

        macro_rules! impl_test {
            ($name:ident: $a:literal = $b:literal) => {
                #[test]
                fn $name() {
                    let a: BigDecimal = $a.parse().unwrap();
                    let b: BigDecimal = $b.parse().unwrap();

                    assert_eq!(&a, &b);
                    assert_eq!(a, b);
                }
            };
        }

        impl_test!(case_zero: "0" = "0.00");
        impl_test!(case_1_1d00: "1" = "1.00");
        impl_test!(case_n1_n1000en3: "-1" = "-1000e-3");
        impl_test!(case_0d000034500_345en7: "0.000034500" = "345e-7");
        // scale difference below 20: non-allocating limb comparison
        impl_test!(case_1_1d0x19: "1" = "1.0000000000000000000");
        // scale difference of 20 or more: decimal digit comparison
        impl_test!(case_n123d45_n123d45x22: "-123.45" = "-123.4500000000000000000000");
    }

    #[test]
    fn test_borrow_neg_cmp() {
        let x: BigDecimal = "1514932018891593.916341142773".parse().unwrap();
        let y: BigDecimal = "1514932018891593916341142773e-12".parse().unwrap();

        assert_eq!(x, y);

        let x_ref = x.to_ref();
        assert_eq!(x_ref, &y);
        assert_ne!(x_ref.neg(), x_ref);
        assert_eq!(x_ref.neg().neg(), x_ref);
    }

    mod cmp_prim {
        use super::*;

        #[test]
        fn cmp_zero_u8() {
            let n = BigDecimal::zero();
            assert!(&n == 0u8);
        }

        #[test]
        fn cmp_negative_decimal_i32() {
            let n: BigDecimal = "-2.50".parse().unwrap();
            assert!(n < 0i32);
            assert!(&n > -3i32);
            assert!(n != -2i32);
        }
    }

    #[cfg(property_tests)]
    mod prop {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #![proptest_config(ProptestConfig { cases: 5000, ..Default::default() })]

            #[test]
            fn cmp_matches_f64(
                f in proptest::num::f64::NORMAL | proptest::num::f64::SUBNORMAL | proptest::num::f64::ZERO,
                g in proptest::num::f64::NORMAL | proptest::num::f64::SUBNORMAL | proptest::num::f64::ZERO
            ) {
                let a: BigDecimal = BigDecimal::from_f64(f).unwrap();
                let b: BigDecimal = BigDecimal::from_f64(g).unwrap();

                let expected = PartialOrd::partial_cmp(&f, &g).unwrap();
                let value = a.cmp(&b);

                prop_assert_eq!(expected, value)
            }
        }
    }
}