1use crate::*;
5use stdlib::num::NonZeroU64;
6
7use arithmetic::store_carry;
8use rounding::NonDigitRoundingData;
9
10
11include!(concat!(env!("OUT_DIR"), "/default_precision.rs"));
13
14
15#[derive(Debug, Clone)]
32pub struct Context {
33 precision: NonZeroU64,
35 rounding: RoundingMode,
37}
38
39impl Context {
40 pub fn new(precision: NonZeroU64, rounding: RoundingMode) -> Self {
42 Context {
43 precision: precision,
44 rounding: rounding,
45 }
46 }
47
48 pub fn with_precision(&self, precision: NonZeroU64) -> Self {
50 Self {
51 precision: precision,
52 ..*self
53 }
54 }
55
56 pub fn with_prec<T: ToPrimitive>(&self, precision: T) -> Option<Self> {
58 precision
59 .to_u64()
60 .and_then(NonZeroU64::new)
61 .map(|prec| self.with_precision(prec))
62 }
63
64 pub fn with_rounding_mode(&self, mode: RoundingMode) -> Self {
66 Self {
67 rounding: mode,
68 ..*self
69 }
70 }
71
72 pub(crate) fn new_truncating(prec: u64) -> Self {
75 Self {
76 rounding: RoundingMode::Down,
77 precision: NonZeroU64::new(prec.max(1)).unwrap(),
78 }
79 }
80
81 pub fn precision(&self) -> NonZeroU64 {
83 self.precision
84 }
85
86 pub fn rounding_mode(&self) -> RoundingMode {
88 self.rounding
89 }
90
91 pub fn round_decimal(&self, n: BigDecimal) -> BigDecimal {
93 n.with_precision_round(self.precision(), self.rounding_mode())
94 }
95
96 pub fn round_decimal_ref<'a, D: Into<BigDecimalRef<'a>>>(&self, n: D) -> BigDecimal {
98 let d = n.into().to_owned();
99 d.with_precision_round(self.precision(), self.rounding_mode())
100 }
101
102 #[allow(dead_code)]
105 pub(crate) fn round_bigint(
106 self, n: num_bigint::BigInt
107 ) -> WithScale<num_bigint::BigInt> {
108 self.rounding.round_bigint_to_prec(n, self.precision)
109 }
110
111 #[allow(dead_code)]
114 pub(crate) fn round_biguint(
115 self, n: num_bigint::BigUint
116 ) -> WithScale<num_bigint::BigUint> {
117 let ndrd = NonDigitRoundingData { mode: self.rounding, sign: Sign::Plus };
118 ndrd.round_biguint_to_prec(n, self.precision)
119 }
120
121 #[allow(dead_code)]
123 pub(crate) fn round_pair(&self, sign: Sign, x: u8, y: u8, trailing_zeros: bool) -> u8 {
124 self.rounding.round_pair(sign, (x, y), trailing_zeros)
125 }
126
127 #[allow(dead_code)]
129 pub(crate) fn round_pair_with_carry(
130 &self,
131 sign: Sign,
132 x: u8,
133 y: u8,
134 trailing_zeros: bool,
135 carry: &mut u8,
136 ) -> u8 {
137 self.rounding.round_pair_with_carry(sign, (x, y), trailing_zeros, carry)
138 }
139
140 pub fn multiply<'a, L, R>(&self, lhs: L, rhs: R) -> BigDecimal
157 where
158 L: Into<BigDecimalRef<'a>>,
159 R: Into<BigDecimalRef<'a>>,
160 {
161 use arithmetic::multiplication::multiply_decimals_with_context;
162
163 let mut result = BigDecimal::zero();
164 multiply_decimals_with_context(&mut result, lhs, rhs, self);
165 result
166 }
167
168 pub fn invert<'a, T: Into<BigDecimalRef<'a>>>(&self, n: T) -> BigDecimal {
184 n.into().inverse_with_context(self)
185 }
186}
187
188impl stdlib::default::Default for Context {
189 fn default() -> Self {
190 Self {
191 precision: NonZeroU64::new(DEFAULT_PRECISION).unwrap(),
192 rounding: RoundingMode::default(),
193 }
194 }
195}
196
197impl Context {
198 pub fn add_refs<'a, 'b, A, B>(&self, a: A, b: B) -> BigDecimal
200 where
201 A: Into<BigDecimalRef<'a>>,
202 B: Into<BigDecimalRef<'b>>,
203 {
204 let mut sum = BigDecimal::zero();
205 self.add_refs_into(a, b, &mut sum);
206 sum
207 }
208
209 pub fn add_refs_into<'a, 'b, A, B>(&self, a: A, b: B, dest: &mut BigDecimal)
211 where
212 A: Into<BigDecimalRef<'a>>,
213 B: Into<BigDecimalRef<'b>>,
214 {
215 let sum = a.into() + b.into();
216 *dest = sum.with_precision_round(self.precision, self.rounding)
217 }
218}
219
220
221#[cfg(test)]
222mod test_context {
223 use super::*;
224
225 #[test]
226 fn constructor_and_setters() {
227 let ctx = Context::default();
228 let c = ctx.with_prec(44).unwrap();
229 assert_eq!(c.precision.get(), 44);
230 assert_eq!(c.rounding, RoundingMode::HalfEven);
231
232 let c = c.with_rounding_mode(RoundingMode::Down);
233 assert_eq!(c.precision.get(), 44);
234 assert_eq!(c.rounding, RoundingMode::Down);
235 }
236
237 #[test]
238 fn sum_two_references() {
239 use stdlib::ops::Neg;
240
241 let ctx = Context::default();
242 let a: BigDecimal = "209682.134972197168613072130300".parse().unwrap();
243 let b: BigDecimal = "3.0782968222271332463325639E-12".parse().unwrap();
244
245 let sum = ctx.add_refs(&a, &b);
246 let expected: BigDecimal =
247 "209682.1349721971716913689525271332463325639".parse().unwrap();
248 assert_eq!(sum, expected);
249
250 let neg_b = b.to_ref().neg();
252
253 let sum = ctx.add_refs(&a, neg_b);
254 let expected: BigDecimal =
255 "209682.1349721971655347753080728667536674361".parse().unwrap();
256 assert_eq!(sum, expected);
257
258 let sum = ctx.with_prec(27).unwrap().with_rounding_mode(RoundingMode::Up).add_refs(&a, neg_b);
259 let expected: BigDecimal =
260 "209682.134972197165534775309".parse().unwrap();
261 assert_eq!(sum, expected);
262 }
263
264 mod round_decimal_ref {
265 use super::*;
266
267 #[test]
268 fn case_bigint_1234567_prec3() {
269 let ctx = Context::default().with_prec(3).unwrap();
270 let i = BigInt::from(1234567);
271 let d = ctx.round_decimal_ref(&i);
272 assert_eq!(d.int_val, 123.into());
273 assert_eq!(d.scale, -4);
274 }
275
276 #[test]
277 fn case_bigint_1234500_prec4_halfup() {
278 let ctx = Context::default()
279 .with_prec(4).unwrap()
280 .with_rounding_mode(RoundingMode::HalfUp);
281 let i = BigInt::from(1234500);
282 let d = ctx.round_decimal_ref(&i);
283 assert_eq!(d.int_val, 1235.into());
284 assert_eq!(d.scale, -3);
285 }
286
287 #[test]
288 fn case_bigint_1234500_prec4_halfeven() {
289 let ctx = Context::default()
290 .with_prec(4).unwrap()
291 .with_rounding_mode(RoundingMode::HalfEven);
292 let i = BigInt::from(1234500);
293 let d = ctx.round_decimal_ref(&i);
294 assert_eq!(d.int_val, 1234.into());
295 assert_eq!(d.scale, -3);
296 }
297
298 #[test]
299 fn case_bigint_1234567_prec10() {
300 let ctx = Context::default().with_prec(10).unwrap();
301 let i = BigInt::from(1234567);
302 let d = ctx.round_decimal_ref(&i);
303 assert_eq!(d.int_val, 1234567000.into());
304 assert_eq!(d.scale, 3);
305 }
306 }
307}