rand/distr/
uniform_float.rs

1// Copyright 2018-2020 Developers of the Rand project.
2// Copyright 2017 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! `UniformFloat` implementation
11
12use super::{Error, SampleBorrow, SampleUniform, UniformSampler};
13use crate::distr::float::IntoFloat;
14use crate::distr::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD};
15use crate::Rng;
16
17#[cfg(feature = "simd_support")]
18use core::simd::prelude::*;
19// #[cfg(feature = "simd_support")]
20// use core::simd::{LaneCount, SupportedLaneCount};
21
22#[cfg(feature = "serde")]
23use serde::{Deserialize, Serialize};
24
/// The back-end implementing [`UniformSampler`] for floating-point types.
///
/// This type is an implementation detail of [`Uniform`]. Use [`Uniform`]
/// instead, unless you are implementing [`UniformSampler`] for a type of your
/// own.
///
/// # Implementation notes
///
/// Each sample is built by transmuting random bits into a float in `[1, 2)`,
/// shifting that value down into `[0, 1)`, and finally scaling and offsetting
/// it into the requested range. The result carries the equivalent of 23
/// random significand bits for an `f32` and 52 for an `f64`.
///
/// # Bias and range errors
///
/// Expect bias within the least-significant bit of the significand.
/// Exclusive range limits are not strictly enforced: when sampling `[a, b)`
/// there is no guarantee that `b` is never produced.
///
/// [`Uniform`]: super::Uniform
#[derive(Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct UniformFloat<X> {
    /// Offset added to every scaled `[0, 1)` value; the lower end of the range.
    low: X,
    /// Multiplier applied to a `[0, 1)` value before `low` is added; may be
    /// slightly smaller than the raw range width so rounding never overshoots.
    scale: X,
}
54
// Generates the `UniformFloat` back-end for one float type, scalar or SIMD.
//
// Macro arguments, in order:
// - `$meta`: optional predicate; when given, every generated `impl` is wrapped
//   in `#[cfg($meta)]`.
// - `$ty`: the float type being sampled (e.g. `f64` or `f64x4`).
// - `$uty`: the unsigned integer type drawn from the RNG for each sample.
// - `$f_scalar`: the per-lane float type, used for its `EPSILON`.
// - `$u_scalar`: the per-lane unsigned type (not referenced by the expansion).
// - `$bits_to_discard`: how many low bits of each `$uty` are shifted out before
//   the remaining bits are turned into a float (the invocations pass integer
//   bits minus significand bits, e.g. `32 - 23`).
macro_rules! uniform_float_impl {
    ($($meta:meta)?, $ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => {
        $(#[cfg($meta)])?
        impl UniformFloat<$ty> {
            /// Construct, reducing `scale` as required to ensure that rounding
            /// can never yield values greater than `high`.
            ///
            /// Note: though it may be tempting to use a variant of this method
            /// to ensure that samples from `[low, high)` are always strictly
            /// less than `high`, this approach may be very slow where
            /// `scale.abs()` is much smaller than `high.abs()`
            /// (example: `low=0.99999999997819644, high=1.`).
            fn new_bounded(low: $ty, high: $ty, mut scale: $ty) -> Self {
                // The largest value `sample_unit` can return: `1 - EPSILON`.
                let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON);

                // Decrease `scale` (via `decrease_masked`) in every lane whose
                // largest possible sample would land above `high`, until no
                // lane does.
                loop {
                    let mask = (scale * max_rand + low).gt_mask(high);
                    if !mask.any() {
                        break;
                    }
                    scale = scale.decrease_masked(mask);
                }

                debug_assert!(<$ty>::splat(0.0).all_le(scale));

                UniformFloat { low, scale }
            }

            /// Draws one value in `[0, 1)` from a single integer RNG output.
            fn sample_unit<R: Rng + ?Sized>(rng: &mut R) -> $ty {
                // Keep only the high bits that fit in the significand and build
                // a float in [1, 2) from them.
                let value1_2 = (rng.random::<$uty>() >> $uty::splat($bits_to_discard))
                    .into_float_with_exponent(0);

                // Move to [0, 1) so that multiplying by a finite `scale` cannot
                // overflow to infinity.
                value1_2 - <$ty>::splat(1.0)
            }
        }

        $(#[cfg($meta)])?
        impl SampleUniform for $ty {
            type Sampler = UniformFloat<$ty>;
        }

        $(#[cfg($meta)])?
        impl UniformSampler for UniformFloat<$ty> {
            type X = $ty;

            fn new<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                // Only checked in debug builds; in release, non-finite input is
                // still rejected below (as `EmptyRange` or `NonFinite`).
                #[cfg(debug_assertions)]
                if !(low.all_finite()) || !(high.all_finite()) {
                    return Err(Error::NonFinite);
                }
                if !(low.all_lt(high)) {
                    return Err(Error::EmptyRange);
                }

                let scale = high - low;
                if !(scale.all_finite()) {
                    return Err(Error::NonFinite);
                }

                Ok(Self::new_bounded(low, high, scale))
            }

            fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                #[cfg(debug_assertions)]
                if !(low.all_finite()) || !(high.all_finite()) {
                    return Err(Error::NonFinite);
                }
                if !low.all_le(high) {
                    return Err(Error::EmptyRange);
                }

                // Stretch the range so that the largest unit sample
                // (`1 - EPSILON`) maps onto `high`.
                let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON);
                let scale = (high - low) / max_rand;
                if !scale.all_finite() {
                    return Err(Error::NonFinite);
                }

                Ok(Self::new_bounded(low, high, scale))
            }

            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
                // We don't use `f64::mul_add`, because it is not available with
                // `no_std`. Furthermore, it is slower for some targets (but
                // faster for others). However, the order of multiplication and
                // addition is important, because on some platforms (e.g. ARM)
                // it will be optimized to a single (non-FMA) instruction.
                Self::sample_unit(rng) * self.scale + self.low
            }

            /// Samples once from `[low, high)`, validating the bounds exactly
            /// like [`UniformSampler::new`] does.
            ///
            /// Returns `Err(Error::EmptyRange)` when `low >= high` in any lane,
            /// and `Err(Error::NonFinite)` when the range width is not finite
            /// (or, in debug builds, when either bound is not finite).
            #[inline]
            fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Result<Self::X, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                #[cfg(debug_assertions)]
                if !low.all_finite() || !high.all_finite() {
                    return Err(Error::NonFinite);
                }
                // `[low, low)` is empty: reject it as `new` does, instead of
                // letting the inclusive sampler below return `low`.
                if !low.all_lt(high) {
                    return Err(Error::EmptyRange);
                }
                // The draw itself is shared with the inclusive sampler; as
                // documented on `UniformFloat`, rounding means `high` is not
                // guaranteed to be excluded.
                Self::sample_single_inclusive(low_b, high_b, rng)
            }

            /// Samples once from `[low, high]` without building a sampler.
            ///
            /// Returns `Err(Error::EmptyRange)` when `low > high` in any lane,
            /// and `Err(Error::NonFinite)` when the range width is not finite
            /// (or, in debug builds, when either bound is not finite).
            #[inline]
            fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Result<Self::X, Error>
            where
                B1: SampleBorrow<Self::X> + Sized,
                B2: SampleBorrow<Self::X> + Sized,
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                #[cfg(debug_assertions)]
                if !low.all_finite() || !high.all_finite() {
                    return Err(Error::NonFinite);
                }
                if !low.all_le(high) {
                    return Err(Error::EmptyRange);
                }
                let scale = high - low;
                if !scale.all_finite() {
                    return Err(Error::NonFinite);
                }

                // Doing multiply before addition allows some architectures
                // to use a single instruction.
                Ok(Self::sample_unit(rng) * scale + low)
            }
        }
    };
}
198
199uniform_float_impl! { , f32, u32, f32, u32, 32 - 23 }
200uniform_float_impl! { , f64, u64, f64, u64, 64 - 52 }
201
202#[cfg(feature = "simd_support")]
203uniform_float_impl! { feature = "simd_support", f32x2, u32x2, f32, u32, 32 - 23 }
204#[cfg(feature = "simd_support")]
205uniform_float_impl! { feature = "simd_support", f32x4, u32x4, f32, u32, 32 - 23 }
206#[cfg(feature = "simd_support")]
207uniform_float_impl! { feature = "simd_support", f32x8, u32x8, f32, u32, 32 - 23 }
208#[cfg(feature = "simd_support")]
209uniform_float_impl! { feature = "simd_support", f32x16, u32x16, f32, u32, 32 - 23 }
210
211#[cfg(feature = "simd_support")]
212uniform_float_impl! { feature = "simd_support", f64x2, u64x2, f64, u64, 64 - 52 }
213#[cfg(feature = "simd_support")]
214uniform_float_impl! { feature = "simd_support", f64x4, u64x4, f64, u64, 64 - 52 }
215#[cfg(feature = "simd_support")]
216uniform_float_impl! { feature = "simd_support", f64x8, u64x8, f64, u64, 64 - 52 }
217
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distr::{utils::FloatSIMDScalarUtils, Uniform};
    use crate::rngs::mock::StepRng;

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_floats() {
        let mut rng = crate::test::rng(252);
        let mut zero_rng = StepRng::new(0, 0);
        let mut max_rng = StepRng::new(0xffff_ffff_ffff_ffff, 0);
        macro_rules! t {
            ($ty:ty, $f_scalar:ident, $bits_shifted:expr) => {{
                let v: &[($f_scalar, $f_scalar)] = &[
                    (0.0, 100.0),
                    (-1e35, -1e25),
                    (1e-35, 1e-25),
                    (-1e35, 1e35),
                    (<$f_scalar>::from_bits(0), <$f_scalar>::from_bits(3)),
                    (-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)),
                    (-<$f_scalar>::from_bits(5), 0.0),
                    (-<$f_scalar>::from_bits(7), -0.0),
                    (0.1 * $f_scalar::MAX, $f_scalar::MAX),
                    (-$f_scalar::MAX * 0.2, $f_scalar::MAX * 0.7),
                ];
                for &(low_scalar, high_scalar) in v.iter() {
                    for lane in 0..<$ty>::LEN {
                        let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
                        let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
                        let my_uniform = Uniform::new(low, high).unwrap();
                        let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap();
                        for _ in 0..100 {
                            let v = rng.sample(my_uniform).extract_lane(lane);
                            assert!(low_scalar <= v && v <= high_scalar);
                            let v = rng.sample(my_incl_uniform).extract_lane(lane);
                            assert!(low_scalar <= v && v <= high_scalar);
                            let v =
                                <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng)
                                    .unwrap()
                                    .extract_lane(lane);
                            assert!(low_scalar <= v && v <= high_scalar);
                            let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(
                                low, high, &mut rng,
                            )
                            .unwrap()
                            .extract_lane(lane);
                            assert!(low_scalar <= v && v <= high_scalar);
                        }

                        assert_eq!(
                            rng.sample(Uniform::new_inclusive(low, low).unwrap())
                                .extract_lane(lane),
                            low_scalar
                        );

                        assert_eq!(zero_rng.sample(my_uniform).extract_lane(lane), low_scalar);
                        assert_eq!(
                            zero_rng.sample(my_incl_uniform).extract_lane(lane),
                            low_scalar
                        );
                        assert_eq!(
                            <$ty as SampleUniform>::Sampler::sample_single(
                                low,
                                high,
                                &mut zero_rng
                            )
                            .unwrap()
                            .extract_lane(lane),
                            low_scalar
                        );
                        assert_eq!(
                            <$ty as SampleUniform>::Sampler::sample_single_inclusive(
                                low,
                                high,
                                &mut zero_rng
                            )
                            .unwrap()
                            .extract_lane(lane),
                            low_scalar
                        );

                        assert!(max_rng.sample(my_uniform).extract_lane(lane) <= high_scalar);
                        assert!(max_rng.sample(my_incl_uniform).extract_lane(lane) <= high_scalar);
                        // `sample_single` performs its draw via
                        // `sample_single_inclusive`, so it must also stay within
                        // bounds for the maximal RNG output.
                        assert!(
                            <$ty as SampleUniform>::Sampler::sample_single(
                                low,
                                high,
                                &mut max_rng
                            )
                            .unwrap()
                            .extract_lane(lane)
                                <= high_scalar
                        );
                        assert!(
                            <$ty as SampleUniform>::Sampler::sample_single_inclusive(
                                low,
                                high,
                                &mut max_rng
                            )
                            .unwrap()
                            .extract_lane(lane)
                                <= high_scalar
                        );

                        // Don't run this test for really tiny differences between high and low
                        // since for those rounding might result in selecting high for a very
                        // long time.
                        if (high_scalar - low_scalar) > 0.0001 {
                            let mut lowering_max_rng = StepRng::new(
                                0xffff_ffff_ffff_ffff,
                                (-1i64 << $bits_shifted) as u64,
                            );
                            assert!(
                                <$ty as SampleUniform>::Sampler::sample_single(
                                    low,
                                    high,
                                    &mut lowering_max_rng
                                )
                                .unwrap()
                                .extract_lane(lane)
                                    <= high_scalar
                            );
                        }
                    }
                }

                assert_eq!(
                    rng.sample(Uniform::new_inclusive($f_scalar::MAX, $f_scalar::MAX).unwrap()),
                    $f_scalar::MAX
                );
                assert_eq!(
                    rng.sample(Uniform::new_inclusive(-$f_scalar::MAX, -$f_scalar::MAX).unwrap()),
                    -$f_scalar::MAX
                );
            }};
        }

        t!(f32, f32, 32 - 23);
        t!(f64, f64, 64 - 52);
        #[cfg(feature = "simd_support")]
        {
            t!(f32x2, f32, 32 - 23);
            t!(f32x4, f32, 32 - 23);
            t!(f32x8, f32, 32 - 23);
            t!(f32x16, f32, 32 - 23);
            t!(f64x2, f64, 64 - 52);
            t!(f64x4, f64, 64 - 52);
            t!(f64x8, f64, 64 - 52);
        }
    }

    #[test]
    fn test_float_overflow() {
        assert_eq!(Uniform::try_from(f64::MIN..f64::MAX), Err(Error::NonFinite));
    }

    #[test]
    #[should_panic]
    fn test_float_overflow_single() {
        let mut rng = crate::test::rng(252);
        rng.random_range(f64::MIN..f64::MAX);
    }

    // Invalid bounds are reported through `Result`, not by panicking, so this
    // test needs neither `std` nor unwinding.
    #[test]
    fn test_float_assertions() {
        fn range<T: SampleUniform>(low: T, high: T) -> Result<T, Error> {
            let mut rng = crate::test::rng(253);
            T::Sampler::sample_single(low, high, &mut rng)
        }

        macro_rules! t {
            ($ty:ident, $f_scalar:ident) => {{
                let v: &[($f_scalar, $f_scalar)] = &[
                    ($f_scalar::NAN, 0.0),
                    (1.0, $f_scalar::NAN),
                    ($f_scalar::NAN, $f_scalar::NAN),
                    (1.0, 0.5),
                    ($f_scalar::MAX, -$f_scalar::MAX),
                    ($f_scalar::INFINITY, $f_scalar::INFINITY),
                    ($f_scalar::NEG_INFINITY, $f_scalar::NEG_INFINITY),
                    ($f_scalar::NEG_INFINITY, 5.0),
                    (5.0, $f_scalar::INFINITY),
                    ($f_scalar::NAN, $f_scalar::INFINITY),
                    ($f_scalar::NEG_INFINITY, $f_scalar::NAN),
                    ($f_scalar::NEG_INFINITY, $f_scalar::INFINITY),
                ];
                for &(low_scalar, high_scalar) in v.iter() {
                    for lane in 0..<$ty>::LEN {
                        let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
                        let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
                        assert!(range(low, high).is_err());
                        assert!(Uniform::new(low, high).is_err());
                        assert!(Uniform::new_inclusive(low, high).is_err());
                        assert!(Uniform::new(low, low).is_err());
                    }
                }
            }};
        }

        t!(f32, f32);
        t!(f64, f64);
        #[cfg(feature = "simd_support")]
        {
            t!(f32x2, f32);
            t!(f32x4, f32);
            t!(f32x8, f32);
            t!(f32x16, f32);
            t!(f64x2, f64);
            t!(f64x4, f64);
            t!(f64x8, f64);
        }
    }

    #[test]
    fn test_uniform_from_std_range() {
        let r = Uniform::try_from(2.0f64..7.0).unwrap();
        assert_eq!(r.0.low, 2.0);
        assert_eq!(r.0.scale, 5.0);
    }

    #[test]
    fn test_uniform_from_std_range_bad_limits() {
        #![allow(clippy::reversed_empty_ranges)]
        assert!(Uniform::try_from(100.0..10.0).is_err());
        assert!(Uniform::try_from(100.0..100.0).is_err());
    }

    #[test]
    fn test_uniform_from_std_range_inclusive() {
        let r = Uniform::try_from(2.0f64..=7.0).unwrap();
        assert_eq!(r.0.low, 2.0);
        assert!(r.0.scale > 5.0);
        assert!(r.0.scale < 5.0 + 1e-14);
    }

    #[test]
    fn test_uniform_from_std_range_inclusive_bad_limits() {
        #![allow(clippy::reversed_empty_ranges)]
        assert!(Uniform::try_from(100.0..=10.0).is_err());
        assert!(Uniform::try_from(100.0..=99.0).is_err());
    }
}