bigdecimal/rounding.rs
1//! Rounding structures and subroutines
2
3use crate::*;
4use crate::arithmetic::{add_carry, store_carry, extend_adding_with_carry};
5use stdlib;
6use stdlib::num::NonZeroU64;
7
// const DEFAULT_ROUNDING_MODE: RoundingMode = ${RUST_BIGDECIMAL_DEFAULT_ROUNDING_MODE} or HalfEven;
9include!(concat!(env!("OUT_DIR"), "/default_rounding_mode.rs"));
10
/// Determines how to calculate the last digit of the number
///
/// Default rounding mode is `HalfEven`, overwritable at compile-time
/// by setting the environment variable `RUST_BIGDECIMAL_DEFAULT_ROUNDING_MODE`
/// to the name of the mode.
///
/// `Up`, `Down`, `Ceiling` and `Floor` ignore how large a non-zero
/// discarded part is. The `Half*` modes round to the nearest neighbor and
/// only differ from one another when the discarded part is *exactly* one
/// half (a 5 followed by nothing but zeros); in that tie case the mode
/// picks the direction.
///
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum RoundingMode {
    /// Always round away from zero
    ///
    /// * 5.5 → 6.0
    /// * 2.5 → 3.0
    /// * 1.6 → 2.0
    /// * 1.1 → 2.0
    /// * -1.1 → -2.0
    /// * -1.6 → -2.0
    /// * -2.5 → -3.0
    /// * -5.5 → -6.0
    Up,

    /// Always round towards zero
    ///
    /// * 5.5 → 5.0
    /// * 2.5 → 2.0
    /// * 1.6 → 1.0
    /// * 1.1 → 1.0
    /// * -1.1 → -1.0
    /// * -1.6 → -1.0
    /// * -2.5 → -2.0
    /// * -5.5 → -5.0
    Down,

    /// Towards +∞
    ///
    /// * 5.5 → 6.0
    /// * 2.5 → 3.0
    /// * 1.6 → 2.0
    /// * 1.1 → 2.0
    /// * -1.1 → -1.0
    /// * -1.6 → -1.0
    /// * -2.5 → -2.0
    /// * -5.5 → -5.0
    Ceiling,

    /// Towards -∞
    ///
    /// * 5.5 → 5.0
    /// * 2.5 → 2.0
    /// * 1.6 → 1.0
    /// * 1.1 → 1.0
    /// * -1.1 → -2.0
    /// * -1.6 → -2.0
    /// * -2.5 → -3.0
    /// * -5.5 → -6.0
    Floor,

    /// Round to 'nearest neighbor'; if exactly halfway, round away from zero
    ///
    /// * 5.5 → 6.0
    /// * 2.5 → 3.0
    /// * 1.6 → 2.0
    /// * 1.1 → 1.0
    /// * -1.1 → -1.0
    /// * -1.6 → -2.0
    /// * -2.5 → -3.0
    /// * -5.5 → -6.0
    HalfUp,

    /// Round to 'nearest neighbor'; if exactly halfway, round towards zero
    ///
    /// A discarded part larger than one half still rounds away from zero
    /// (e.g. 2.51 → 3.0 when rounding to the units digit).
    ///
    /// * 5.5 → 5.0
    /// * 2.5 → 2.0
    /// * 1.6 → 2.0
    /// * 1.1 → 1.0
    /// * -1.1 → -1.0
    /// * -1.6 → -2.0
    /// * -2.5 → -2.0
    /// * -5.5 → -5.0
    HalfDown,

    /// Round to 'nearest neighbor'; if exactly halfway, round towards
    /// the nearest even digit
    ///
    /// * 5.5 → 6.0
    /// * 2.5 → 2.0
    /// * 1.6 → 2.0
    /// * 1.1 → 1.0
    /// * -1.1 → -1.0
    /// * -1.6 → -2.0
    /// * -2.5 → -2.0
    /// * -5.5 → -6.0
    HalfEven,
}
106
107
108impl RoundingMode {
109 /// Perform the rounding operation
110 ///
111 /// Parameters
112 /// ----------
113 /// * sign (Sign) - Sign of the number to be rounded
114 /// * pair (u8, u8) - The two digits in question to be rounded.
115 /// i.e. to round 0.345 to two places, you would pass (4, 5).
116 /// As decimal digits, they
117 /// must be less than ten!
118 /// * trailing_zeros (bool) - True if all digits after the pair are zero.
119 /// This has an effect if the right hand digit is 0 or 5.
120 ///
121 /// Returns
122 /// -------
123 /// Returns the first number of the pair, rounded. The sign is not preserved.
124 ///
125 /// Examples
126 /// --------
127 /// - To round 2341, pass in `Plus, (4, 1), true` → get 4 or 5 depending on scheme
128 /// - To round -0.1051, to two places: `Minus, (0, 5), false` → returns either 0 or 1
129 /// - To round -0.1, pass in `true, (0, 1)` → returns either 0 or 1
130 ///
131 /// Calculation of pair of digits from full number, and the replacement of that number
132 /// should be handled separately
133 ///
134 pub fn round_pair(&self, sign: Sign, pair: (u8, u8), trailing_zeros: bool) -> u8 {
135 use self::RoundingMode::*;
136 use stdlib::cmp::Ordering::*;
137
138 let (lhs, rhs) = pair;
139 // if all zero after digit, never round
140 if rhs == 0 && trailing_zeros {
141 return lhs;
142 }
143 let up = lhs + 1;
144 let down = lhs;
145 match (*self, rhs.cmp(&5)) {
146 (Up, _) => up,
147 (Down, _) => down,
148 (Floor, _) => if sign == Sign::Minus { up } else { down },
149 (Ceiling, _) => if sign == Sign::Minus { down } else { up },
150 (_, Less) => down,
151 (_, Greater) => up,
152 (_, Equal) if !trailing_zeros => up,
153 (HalfUp, Equal) => up,
154 (HalfDown, Equal) => down,
155 (HalfEven, Equal) => if lhs % 2 == 0 { down } else { up },
156 }
157 }
158
159 /// Round digits, and if rounded up to 10, store 1 in carry and return zero
160 pub(crate) fn round_pair_with_carry(
161 &self,
162 sign: Sign,
163 pair: (u8, u8),
164 trailing_zeros: bool,
165 carry: &mut u8,
166 ) -> u8 {
167 let r = self.round_pair(sign, pair, trailing_zeros);
168 store_carry(r, carry)
169 }
170
171 /// Round value at particular digit, returning replacement digit
172 ///
173 /// Parameters
174 /// ----------
175 /// * at_digit (NonZeroU8) - 0-based index of digit at which to round.
176 /// 0 would be the first digit, and would
177 ///
178 /// * sign (Sign) - Sign of the number to be rounded
179 /// * value (u32) - The number containing digits to be rounded.
180 /// * trailing_zeros (bool) - True if all digits after the value are zero.
181 ///
182 /// Returns
183 /// -------
184 /// Returns the first number of the pair, rounded. The sign is not preserved.
185 ///
186 /// Examples
187 /// --------
188 /// - To round 823418, at digit-index 3: `3, Plus, 823418, true` → 823000 or 824000, depending on scheme
189 /// - To round -100205, at digit-index 1: `1, Minus, 100205, true` → 100200 or 100210
190 ///
191 /// Calculation of pair of digits from full number, and the replacement of that number
192 /// should be handled separately
193 ///
194 pub fn round_u32(
195 &self,
196 at_digit: stdlib::num::NonZeroU8,
197 sign: Sign,
198 value: u32,
199 trailing_zeros: bool,
200 ) -> u32 {
201 let shift = 10u32.pow(at_digit.get() as u32 - 1);
202 let splitter = shift * 10;
203
204 // split 'value' into high and low
205 let (top, bottom) = num_integer::div_rem(value, splitter);
206 let lhs = (top % 10) as u8;
207 let (rhs, remainder) = num_integer::div_rem(bottom, shift);
208 let pair = (lhs, rhs as u8);
209 let rounded = self.round_pair(sign, pair, trailing_zeros && remainder == 0);
210
211 // replace low digit with rounded value
212 let full = top - lhs as u32 + rounded as u32;
213
214 // shift rounded value back to position
215 full * splitter
216 }
217
218 /// Round the bigint to prec digits
219 pub(crate) fn round_bigint_to_prec(
220 self, n: num_bigint::BigInt, prec: NonZeroU64
221 ) -> WithScale<num_bigint::BigInt> {
222 let (sign, mut biguint) = n.into_parts();
223
224 let ndrd = NonDigitRoundingData { mode: self, sign };
225 let ndigits = round_biguint_inplace(&mut biguint, prec, ndrd);
226
227 let result = BigInt::from_biguint(sign, biguint);
228 WithScale::from((result, -ndigits))
229 }
230
231 /// Hint used to skip calculating trailing_zeros if they don't matter
232 fn needs_trailing_zeros(&self, insig_digit: u8) -> bool {
233 use RoundingMode::*;
234
235 // only need trailing zeros if the rounding digit is 0 or 5
236 if matches!(self, HalfUp | HalfDown | HalfEven) {
237 insig_digit == 5
238 } else {
239 insig_digit == 0
240 }
241 }
242
243}
244
245/// Return compile-time constant default rounding mode
246///
247/// Defined by RUST_BIGDECIMAL_DEFAULT_ROUNDING_MODE at compile time
248///
249impl Default for RoundingMode {
250 fn default() -> Self {
251 DEFAULT_ROUNDING_MODE
252 }
253}
254
255
256/// All non-digit information required to round digits
257///
258/// Just the mode and the sign.
259///
260#[derive(Debug, Clone, Copy)]
261pub(crate) struct NonDigitRoundingData {
262 /// Rounding mode
263 pub mode: RoundingMode,
264 /// Sign of digits
265 pub sign: Sign,
266}
267
268impl NonDigitRoundingData {
269 /// Round pair of digits, storing overflow (10) in the carry
270 pub fn round_pair(&self, pair: (u8, u8), trailing_zeros: bool) -> u8 {
271 self.mode.round_pair(self.sign, pair, trailing_zeros)
272 }
273
274 /// round-pair with carry-digits
275 pub fn round_pair_with_carry(&self, pair: (u8, u8), trailing_zeros: bool, carry: &mut u8) -> u8 {
276 self.mode.round_pair_with_carry(self.sign, pair, trailing_zeros, carry)
277 }
278
279 /// Use sign and default rounding mode
280 pub fn default_with_sign(sign: Sign) -> Self {
281 NonDigitRoundingData { sign, mode: RoundingMode::default() }
282 }
283
284 /// Round BigUint to requested precision, using mode and sign in self
285 ///
286 /// Returns the biguint with at most 'prec' digits, and scale
287 /// indicating how many decimal digits were removed.
288 ///
289 pub(crate) fn round_biguint_to_prec(
290 self, mut n: num_bigint::BigUint, prec: NonZeroU64
291 ) -> WithScale<num_bigint::BigUint> {
292 let ndigits = round_biguint_inplace(&mut n, prec, self);
293 WithScale::from((n, -ndigits))
294 }
295
296}
297
298
299/// Relevant information about insignificant digits, used for rounding
300///
301/// If rounding at indicated point:
302///
303/// ```txt
304/// aaaaizzzzzzzz
305/// ^
306/// ```
307///
308/// 'a' values are significant, 'i' is the insignificant digit,
309/// and trailing_zeros is true if all 'z' are 0.
310///
311#[derive(Debug,Clone,Copy)]
312pub(crate) struct InsigData {
313 /// highest insignificant digit
314 pub digit: u8,
315
316 /// true if all digits more insignificant than 'digit' is zero
317 ///
318 /// This is only useful if relevant for the rounding mode, it
319 /// may be 'wrong' in these cases.
320 pub trailing_zeros: bool,
321
322 /// rounding-mode and sign
323 pub rounding_data: NonDigitRoundingData
324}
325
326#[allow(dead_code)]
327impl InsigData {
328 /// Build from insig data and lazily calculated trailing-zeros callable
329 pub fn from_digit_and_lazy_trailing_zeros(
330 rounder: NonDigitRoundingData,
331 insig_digit: u8,
332 calc_trailing_zeros: impl FnOnce() -> bool
333 ) -> Self {
334 Self {
335 digit: insig_digit,
336 trailing_zeros: rounder.mode.needs_trailing_zeros(insig_digit) && calc_trailing_zeros(),
337 rounding_data: rounder,
338 }
339 }
340
341 /// Build from slice of insignificant little-endian digits
342 pub fn from_digit_slice(rounder: NonDigitRoundingData, digits: &[u8]) -> Self {
343 match digits.split_last() {
344 Some((&d0, trailing)) => {
345 Self::from_digit_and_lazy_trailing_zeros(
346 rounder, d0, || trailing.iter().all(Zero::is_zero)
347 )
348 }
349 None => {
350 Self {
351 digit: 0,
352 trailing_zeros: true,
353 rounding_data: rounder,
354 }
355 }
356 }
357 }
358
359 /// from sum of overlapping digits, (a is longer than b)
360 pub fn from_overlapping_digits_backward_sum(
361 rounder: NonDigitRoundingData,
362 mut a_digits: stdlib::iter::Rev<stdlib::slice::Iter<u8>>,
363 mut b_digits: stdlib::iter::Rev<stdlib::slice::Iter<u8>>,
364 carry: &mut u8,
365 ) -> Self {
366 debug_assert!(a_digits.len() >= b_digits.len());
367 debug_assert_eq!(carry, &0);
368
369 // most-significant insignificant digit
370 let insig_digit;
371 match (a_digits.next(), b_digits.next()) {
372 (Some(a), Some(b)) => {
373 // store 'full', initial sum, we will handle carry below
374 insig_digit = a + b;
375 }
376 (Some(d), None) | (None, Some(d)) => {
377 insig_digit = *d;
378 }
379 (None, None) => {
380 // both digit slices were empty; all zeros
381 return Self {
382 digit: 0,
383 trailing_zeros: true,
384 rounding_data: rounder,
385 };
386 }
387 };
388
389 // find first non-nine value
390 let mut sum = 9;
391 while sum == 9 {
392 let next_a = a_digits.next().unwrap_or(&0);
393 let next_b = b_digits.next().unwrap_or(&0);
394 sum = next_a + next_b;
395 }
396
397 // if previous sum was greater than ten,
398 // the one would carry through all the 9s
399 let sum = store_carry(sum, carry);
400
401 // propagate carry to the highest insignificant digit
402 let insig_digit = add_carry(insig_digit, carry);
403
404 // if the last 'sum' value isn't zero, or if any remaining
405 // digit is not zero, then it's not trailing zeros
406 let trailing_zeros = sum == 0
407 && rounder.mode.needs_trailing_zeros(insig_digit)
408 && a_digits.all(Zero::is_zero)
409 && b_digits.all(Zero::is_zero);
410
411 Self {
412 digit: insig_digit,
413 trailing_zeros: trailing_zeros,
414 rounding_data: rounder,
415 }
416 }
417
418 pub fn round_digit(&self, digit: u8) -> u8 {
419 self.rounding_data.round_pair((digit, self.digit), self.trailing_zeros)
420 }
421
422 pub fn round_digit_with_carry(&self, digit: u8, carry: &mut u8) -> u8 {
423 self.rounding_data.round_pair_with_carry((digit, self.digit), self.trailing_zeros, carry)
424 }
425
426 pub fn round_slice_into(&self, dest: &mut Vec<u8>, digits: &[u8]) {
427 let (&d0, rest) = digits.split_first().unwrap_or((&0, &[]));
428 let digits = rest.iter().copied();
429 let mut carry = 0;
430 let r0 = self.round_digit_with_carry(d0, &mut carry);
431 dest.push(r0);
432 extend_adding_with_carry(dest, digits, &mut carry);
433 if !carry.is_zero() {
434 dest.push(carry);
435 }
436 }
437
438 #[allow(dead_code)]
439 pub fn round_slice_into_with_carry(&self, dest: &mut Vec<u8>, digits: &[u8], carry: &mut u8) {
440 let (&d0, rest) = digits.split_first().unwrap_or((&0, &[]));
441 let digits = rest.iter().copied();
442 let r0 = self.round_digit_with_carry(d0, carry);
443 dest.push(r0);
444
445 extend_adding_with_carry(dest, digits, carry);
446 }
447}
448
449/// Round BigUint n to 'prec' digits
450fn round_biguint_inplace(
451 n: &mut num_bigint::BigUint,
452 prec: NonZeroU64,
453 rounder: NonDigitRoundingData,
454) -> i64 {
455 use arithmetic::modulo::{mod_ten_2p64_le, mod_100_uint};
456 use arithmetic::decimal::count_digits_biguint;
457
458 let digit_count = count_digits_biguint(n);
459 let digits_to_remove = digit_count.saturating_sub(prec.get());
460 if digits_to_remove == 0 {
461 return 0;
462 }
463
464 if digits_to_remove == 1 {
465 let insig_digit = mod_ten_2p64_le(n.iter_u64_digits());
466 *n /= 10u8;
467 let sig_digit = mod_ten_2p64_le(n.iter_u64_digits());
468 let rounded_digit = rounder.round_pair((sig_digit, insig_digit), true);
469 *n += rounded_digit - sig_digit;
470 if rounded_digit != 10 {
471 return 1;
472 }
473 let digit_count = count_digits_biguint(n);
474 if digit_count == prec.get() {
475 return 1;
476 }
477 debug_assert_eq!(digit_count, prec.get() + 1);
478 *n /= 10u8;
479 return 2;
480 }
481
482 let shifter = ten_to_the_uint(digits_to_remove - 1);
483 let low_digits = &(*n) % &shifter;
484 let trailing_zeros = low_digits.is_zero();
485
486 *n /= &shifter;
487 let u = mod_100_uint(n);
488 let (sig_digit, insig_digit) = u.div_rem(&10);
489 let rounded_digit = rounder.round_pair((sig_digit, insig_digit), trailing_zeros);
490 *n /= 10u8;
491 *n += rounded_digit - sig_digit;
492
493 if rounded_digit != 10 {
494 return digits_to_remove as i64;
495 }
496
497 let digit_count = count_digits_biguint(n);
498 if digit_count == prec.get() {
499 return digits_to_remove as i64;
500 }
501
502 debug_assert_eq!(digit_count, prec.get() + 1);
503
504 // shift by another digit. Overflow means all significant
505 // digits were nines, so no need to re-round
506 *n /= 10u8;
507 return digits_to_remove as i64 + 1;
508}
509
510
511#[cfg(test)]
512include!("rounding.tests.rs");