rand/distr/weighted/
weighted_index.rs

1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9use super::{Error, Weight};
10use crate::Rng;
11use crate::distr::Distribution;
12use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformSampler};
13
14// Note that this whole module is only imported if feature="alloc" is enabled.
15use alloc::vec::Vec;
16use core::fmt::{self, Debug};
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21/// A distribution using weighted sampling of discrete items.
22///
23/// Sampling a `WeightedIndex` distribution returns the index of a randomly
24/// selected element from the iterator used when the `WeightedIndex` was
25/// created. The chance of a given element being picked is proportional to the
26/// weight of the element. The weights can use any type `X` for which an
27/// implementation of [`Uniform<X>`] exists. The implementation guarantees that
28/// elements with zero weight are never picked, even when the weights are
29/// floating point numbers.
30///
31/// # Performance
32///
33/// Time complexity of sampling from `WeightedIndex` is `O(log N)` where
34/// `N` is the number of weights.
35/// See also [`rand_distr::weighted`] for alternative implementations supporting
36/// potentially-faster sampling or a more easily modifiable tree structure.
37///
38/// A `WeightedIndex<X>` contains a `Vec<X>` and a [`Uniform<X>`] and so its
39/// size is the sum of the size of those objects, possibly plus some alignment.
40///
41/// Creating a `WeightedIndex<X>` will allocate enough space to hold `N - 1`
42/// weights of type `X`, where `N` is the number of weights. However, since
43/// `Vec` doesn't guarantee a particular growth strategy, additional memory
44/// might be allocated but not used. Since the `WeightedIndex` object also
45/// contains an instance of `X::Sampler`, this might cause additional allocations,
46/// though for primitive types, [`Uniform<X>`] doesn't allocate any memory.
47///
48/// Sampling from `WeightedIndex` will result in a single call to
49/// `Uniform<X>::sample` (method of the [`Distribution`] trait), which typically
50/// will request a single value from the underlying [`Rng`], though the
51/// exact number depends on the implementation of `Uniform<X>::sample`.
52///
53/// # Example
54///
55/// ```
56/// use rand::prelude::*;
57/// use rand::distr::weighted::WeightedIndex;
58///
59/// let choices = ['a', 'b', 'c'];
60/// let weights = [2,   1,   1];
61/// let dist = WeightedIndex::new(&weights).unwrap();
62/// let mut rng = rand::rng();
63/// for _ in 0..100 {
64///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
65///     println!("{}", choices[dist.sample(&mut rng)]);
66/// }
67///
68/// let items = [('a', 0.0), ('b', 3.0), ('c', 7.0)];
69/// let dist2 = WeightedIndex::new(items.iter().map(|item| item.1)).unwrap();
70/// for _ in 0..100 {
71///     // 0% chance to print 'a', 30% chance to print 'b', 70% chance to print 'c'
72///     println!("{}", items[dist2.sample(&mut rng)].0);
73/// }
74/// ```
75///
76/// [`Uniform<X>`]: crate::distr::Uniform
77/// [`Rng`]: crate::Rng
78/// [`rand_distr::weighted`]: https://docs.rs/rand_distr/latest/rand_distr/weighted/index.html
79#[derive(Debug, Clone, PartialEq)]
80#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
81pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
82    cumulative_weights: Vec<X>,
83    total_weight: X,
84    weight_distribution: X::Sampler,
85}
86
87impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
88    /// Creates a new a `WeightedIndex` [`Distribution`] using the values
89    /// in `weights`. The weights can use any type `X` for which an
90    /// implementation of [`Uniform<X>`] exists.
91    ///
92    /// Error cases:
93    /// -   [`Error::InvalidInput`] when the iterator `weights` is empty.
94    /// -   [`Error::InvalidWeight`] when a weight is not-a-number or negative.
95    /// -   [`Error::InsufficientNonZero`] when the sum of all weights is zero.
96    /// -   [`Error::Overflow`] when the sum of all weights overflows.
97    ///
98    /// [`Uniform<X>`]: crate::distr::uniform::Uniform
99    pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
100    where
101        I: IntoIterator,
102        I::Item: SampleBorrow<X>,
103        X: Weight,
104    {
105        let mut iter = weights.into_iter();
106        let mut total_weight: X = iter.next().ok_or(Error::InvalidInput)?.borrow().clone();
107
108        let zero = X::ZERO;
109        if !(total_weight >= zero) {
110            return Err(Error::InvalidWeight);
111        }
112
113        let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
114        for w in iter {
115            // Note that `!(w >= x)` is not equivalent to `w < x` for partially
116            // ordered types due to NaNs which are equal to nothing.
117            if !(w.borrow() >= &zero) {
118                return Err(Error::InvalidWeight);
119            }
120            weights.push(total_weight.clone());
121
122            if let Err(()) = total_weight.checked_add_assign(w.borrow()) {
123                return Err(Error::Overflow);
124            }
125        }
126
127        if total_weight == zero {
128            return Err(Error::InsufficientNonZero);
129        }
130        let distr = X::Sampler::new(zero, total_weight.clone()).unwrap();
131
132        Ok(WeightedIndex {
133            cumulative_weights: weights,
134            total_weight,
135            weight_distribution: distr,
136        })
137    }
138
139    /// Update a subset of weights, without changing the number of weights.
140    ///
141    /// `new_weights` must be sorted by the index.
142    ///
143    /// Using this method instead of `new` might be more efficient if only a small number of
144    /// weights is modified. No allocations are performed, unless the weight type `X` uses
145    /// allocation internally.
146    ///
147    /// In case of error, `self` is not modified. Error cases:
148    /// -   [`Error::InvalidInput`] when `new_weights` are not ordered by
149    ///     index or an index is too large.
150    /// -   [`Error::InvalidWeight`] when a weight is not-a-number or negative.
151    /// -   [`Error::InsufficientNonZero`] when the sum of all weights is zero.
152    ///     Note that due to floating-point loss of precision, this case is not
153    ///     always correctly detected; usage of a fixed-point weight type may be
154    ///     preferred.
155    ///
156    /// Updates take `O(N)` time. If you need to frequently update weights, consider
157    /// [`rand_distr::weighted_tree`](https://docs.rs/rand_distr/*/rand_distr/weighted_tree/index.html)
158    /// as an alternative where an update is `O(log N)`.
159    pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), Error>
160    where
161        X: for<'a> core::ops::AddAssign<&'a X>
162            + for<'a> core::ops::SubAssign<&'a X>
163            + Clone
164            + Default,
165    {
166        if new_weights.is_empty() {
167            return Ok(());
168        }
169
170        let zero = <X as Default>::default();
171
172        let mut total_weight = self.total_weight.clone();
173
174        // Check for errors first, so we don't modify `self` in case something
175        // goes wrong.
176        let mut prev_i = None;
177        for &(i, w) in new_weights {
178            if let Some(old_i) = prev_i {
179                if old_i >= i {
180                    return Err(Error::InvalidInput);
181                }
182            }
183            if !(*w >= zero) {
184                return Err(Error::InvalidWeight);
185            }
186            if i > self.cumulative_weights.len() {
187                return Err(Error::InvalidInput);
188            }
189
190            let mut old_w = if i < self.cumulative_weights.len() {
191                self.cumulative_weights[i].clone()
192            } else {
193                self.total_weight.clone()
194            };
195            if i > 0 {
196                old_w -= &self.cumulative_weights[i - 1];
197            }
198
199            total_weight -= &old_w;
200            total_weight += w;
201            prev_i = Some(i);
202        }
203        if total_weight <= zero {
204            return Err(Error::InsufficientNonZero);
205        }
206
207        // Update the weights. Because we checked all the preconditions in the
208        // previous loop, this should never panic.
209        let mut iter = new_weights.iter();
210
211        let mut prev_weight = zero.clone();
212        let mut next_new_weight = iter.next();
213        let &(first_new_index, _) = next_new_weight.unwrap();
214        let mut cumulative_weight = if first_new_index > 0 {
215            self.cumulative_weights[first_new_index - 1].clone()
216        } else {
217            zero.clone()
218        };
219        for i in first_new_index..self.cumulative_weights.len() {
220            match next_new_weight {
221                Some(&(j, w)) if i == j => {
222                    cumulative_weight += w;
223                    next_new_weight = iter.next();
224                }
225                _ => {
226                    let mut tmp = self.cumulative_weights[i].clone();
227                    tmp -= &prev_weight; // We know this is positive.
228                    cumulative_weight += &tmp;
229                }
230            }
231            prev_weight = cumulative_weight.clone();
232            core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
233        }
234
235        self.total_weight = total_weight;
236        self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()).unwrap();
237
238        Ok(())
239    }
240}
241
242/// A lazy-loading iterator over the weights of a `WeightedIndex` distribution.
243/// This is returned by [`WeightedIndex::weights`].
244pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> {
245    weighted_index: &'a WeightedIndex<X>,
246    index: usize,
247}
248
249impl<X> Debug for WeightedIndexIter<'_, X>
250where
251    X: SampleUniform + PartialOrd + Debug,
252    X::Sampler: Debug,
253{
254    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255        f.debug_struct("WeightedIndexIter")
256            .field("weighted_index", &self.weighted_index)
257            .field("index", &self.index)
258            .finish()
259    }
260}
261
262impl<X> Clone for WeightedIndexIter<'_, X>
263where
264    X: SampleUniform + PartialOrd,
265{
266    fn clone(&self) -> Self {
267        WeightedIndexIter {
268            weighted_index: self.weighted_index,
269            index: self.index,
270        }
271    }
272}
273
274impl<X> Iterator for WeightedIndexIter<'_, X>
275where
276    X: for<'b> core::ops::SubAssign<&'b X> + SampleUniform + PartialOrd + Clone,
277{
278    type Item = X;
279
280    fn next(&mut self) -> Option<Self::Item> {
281        match self.weighted_index.weight(self.index) {
282            None => None,
283            Some(weight) => {
284                self.index += 1;
285                Some(weight)
286            }
287        }
288    }
289}
290
291impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
292    /// Returns the weight at the given index, if it exists.
293    ///
294    /// If the index is out of bounds, this will return `None`.
295    ///
296    /// # Example
297    ///
298    /// ```
299    /// use rand::distr::weighted::WeightedIndex;
300    ///
301    /// let weights = [0, 1, 2];
302    /// let dist = WeightedIndex::new(&weights).unwrap();
303    /// assert_eq!(dist.weight(0), Some(0));
304    /// assert_eq!(dist.weight(1), Some(1));
305    /// assert_eq!(dist.weight(2), Some(2));
306    /// assert_eq!(dist.weight(3), None);
307    /// ```
308    pub fn weight(&self, index: usize) -> Option<X>
309    where
310        X: for<'a> core::ops::SubAssign<&'a X>,
311    {
312        use core::cmp::Ordering::*;
313
314        let mut weight = match index.cmp(&self.cumulative_weights.len()) {
315            Less => self.cumulative_weights[index].clone(),
316            Equal => self.total_weight.clone(),
317            Greater => return None,
318        };
319
320        if index > 0 {
321            weight -= &self.cumulative_weights[index - 1];
322        }
323        Some(weight)
324    }
325
326    /// Returns a lazy-loading iterator containing the current weights of this distribution.
327    ///
328    /// If this distribution has not been updated since its creation, this will return the
329    /// same weights as were passed to `new`.
330    ///
331    /// # Example
332    ///
333    /// ```
334    /// use rand::distr::weighted::WeightedIndex;
335    ///
336    /// let weights = [1, 2, 3];
337    /// let mut dist = WeightedIndex::new(&weights).unwrap();
338    /// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![1, 2, 3]);
339    /// dist.update_weights(&[(0, &2)]).unwrap();
340    /// assert_eq!(dist.weights().collect::<Vec<_>>(), vec![2, 2, 3]);
341    /// ```
342    pub fn weights(&self) -> WeightedIndexIter<'_, X>
343    where
344        X: for<'a> core::ops::SubAssign<&'a X>,
345    {
346        WeightedIndexIter {
347            weighted_index: self,
348            index: 0,
349        }
350    }
351
352    /// Returns the sum of all weights in this distribution.
353    pub fn total_weight(&self) -> X {
354        self.total_weight.clone()
355    }
356}
357
358impl<X> Distribution<usize> for WeightedIndex<X>
359where
360    X: SampleUniform + PartialOrd,
361{
362    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
363        let chosen_weight = self.weight_distribution.sample(rng);
364        // Find the first item which has a weight *higher* than the chosen weight.
365        self.cumulative_weights
366            .partition_point(|w| w <= &chosen_weight)
367    }
368}
369
#[cfg(test)]
mod test {
    use super::*;
    use crate::RngExt;

    /// A distribution serialized and deserialized with postcard keeps its
    /// cumulative weights and its total weight.
    #[cfg(feature = "serde")]
    #[test]
    fn test_weightedindex_serde() {
        let original = WeightedIndex::new([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

        let bytes = postcard::to_allocvec(&original).unwrap();
        let restored: WeightedIndex<i32> = postcard::from_bytes(&bytes).unwrap();

        assert_eq!(restored.cumulative_weights, original.cumulative_weights);
        assert_eq!(restored.total_weight, original.total_weight);
    }

    /// NaN weights are rejected by both `new` and `update_weights`.
    #[test]
    fn test_accepting_nan() {
        let nan_first = WeightedIndex::new([f32::NAN, 0.5]);
        assert_eq!(nan_first.unwrap_err(), Error::InvalidWeight);

        let nan_only = WeightedIndex::new([f32::NAN]);
        assert_eq!(nan_only.unwrap_err(), Error::InvalidWeight);

        let nan_last = WeightedIndex::new([0.5, f32::NAN]);
        assert_eq!(nan_last.unwrap_err(), Error::InvalidWeight);

        let mut valid = WeightedIndex::new([0.5, 7.0]).unwrap();
        let updated = valid.update_weights(&[(0, &f32::NAN)]);
        assert_eq!(updated.unwrap_err(), Error::InvalidWeight);
    }

    /// Observed sampling frequencies track the weights, zero weights are
    /// never drawn, and invalid inputs produce the matching error.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weightedindex() {
        let mut r = crate::test::rng(700);
        const N_REPS: u32 = 5000;
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Each count must be within 25% (relative) of its expectation; a
        // count of exactly the expected value passes regardless.
        let verify = |counts: [i32; 14]| {
            for (&weight, &count) in weights.iter().zip(counts.iter()) {
                let exp = (weight * N_REPS) as f32 / total_weight;
                let abs_err = (count as f32 - exp).abs();
                let rel_err = if abs_err != 0.0 {
                    abs_err / exp
                } else {
                    abs_err
                };
                assert!(rel_err <= 0.25);
            }
        };

        // Draws `N_REPS` samples and counts how often each index came up.
        let mut tally = |distr: &WeightedIndex<u32>| {
            let mut counts = [0i32; 14];
            for _ in 0..N_REPS {
                counts[distr.sample(&mut r)] += 1;
            }
            counts
        };

        // WeightedIndex from vec
        let from_vec = WeightedIndex::new(weights.to_vec()).unwrap();
        verify(tally(&from_vec));

        // WeightedIndex from slice
        let from_slice = WeightedIndex::new(&weights[..]).unwrap();
        verify(tally(&from_slice));

        // WeightedIndex from iterator
        let from_iter = WeightedIndex::new(weights.iter()).unwrap();
        verify(tally(&from_iter));

        for _ in 0..5 {
            let second_only = WeightedIndex::new([0, 1]).unwrap();
            assert_eq!(second_only.sample(&mut r), 1);

            let first_only = WeightedIndex::new([1, 0]).unwrap();
            assert_eq!(first_only.sample(&mut r), 0);

            let fifth_only = WeightedIndex::new([0, 0, 0, 0, 10, 0]).unwrap();
            assert_eq!(fifth_only.sample(&mut r), 4);
        }

        assert_eq!(
            WeightedIndex::new(&[10][0..0]).unwrap_err(),
            Error::InvalidInput
        );
        assert_eq!(
            WeightedIndex::new([0]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedIndex::new([10, 20, -1, 30]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new([-10, 20, 1, 30]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(WeightedIndex::new([-10]).unwrap_err(), Error::InvalidWeight);
    }

    /// After `update_weights`, the distribution equals one built directly
    /// from the expected weights.
    #[test]
    fn test_update_weights() {
        let data = [
            (
                &[10u32, 2, 3, 4][..],
                &[(1, &100), (2, &4)][..], // positive change
                &[10, 100, 4, 4][..],
            ),
            (
                &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
                &[(2, &1), (5, &1), (13, &100)][..], // negative change and last element
                &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
            ),
        ];

        for &(initial, update, expected) in &data {
            let initial_total: u32 = initial.iter().sum();
            let mut distr = WeightedIndex::new(initial.to_vec()).unwrap();
            assert_eq!(distr.total_weight, initial_total);

            distr.update_weights(update).unwrap();

            let expected_total: u32 = expected.iter().sum();
            let reference = WeightedIndex::new(expected.to_vec()).unwrap();
            assert_eq!(distr.total_weight, expected_total);
            assert_eq!(distr.total_weight, reference.total_weight);
            assert_eq!(distr.cumulative_weights, reference.cumulative_weights);
        }
    }

    /// Each invalid update is rejected with the expected error variant.
    #[test]
    fn test_update_weights_errors() {
        let data = [
            (
                &[1i32, 0, 0][..],
                &[(0, &0)][..],
                Error::InsufficientNonZero,
            ),
            (
                &[10, 10, 10, 10][..],
                &[(1, &-11)][..],
                Error::InvalidWeight, // A weight is negative
            ),
            (
                &[1, 2, 3, 4, 5][..],
                &[(1, &5), (0, &5)][..], // Wrong order
                Error::InvalidInput,
            ),
            (
                &[1][..],
                &[(1, &1)][..], // Index too large
                Error::InvalidInput,
            ),
        ];

        for (initial, update, expected_err) in data.iter() {
            let initial_total: i32 = initial.iter().sum();
            let mut distr = WeightedIndex::new(initial.to_vec()).unwrap();
            assert_eq!(distr.total_weight, initial_total);

            if let Err(e) = distr.update_weights(update) {
                assert_eq!(e, *expected_err);
            } else {
                panic!("Expected update_weights to fail, but it succeeded");
            }
        }
    }

    /// `weight(i)` returns every input weight and `None` one past the end.
    #[test]
    fn test_weight_at() {
        let data = [
            &[1][..],
            &[10, 2, 3, 4][..],
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
            &[u32::MAX][..],
        ];

        for &weights in &data {
            let distr = WeightedIndex::new(weights.to_vec()).unwrap();
            for (i, &expected) in weights.iter().enumerate() {
                assert_eq!(distr.weight(i), Some(expected));
            }
            assert_eq!(distr.weight(weights.len()), None);
        }
    }

    /// `weights()` yields the input weights in their original order.
    #[test]
    fn test_weights() {
        let data = [
            &[1][..],
            &[10, 2, 3, 4][..],
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
            &[u32::MAX][..],
        ];

        for &weights in &data {
            let distr = WeightedIndex::new(weights.to_vec()).unwrap();
            let yielded: Vec<_> = distr.weights().collect();
            assert_eq!(yielded, weights.to_vec());
        }
    }

    /// Pins the exact index sequence produced for a fixed seed.
    #[test]
    fn value_stability() {
        // Fills `buf` with samples from a freshly seeded generator and
        // compares them with `expected`.
        fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
            weights: I,
            buf: &mut [usize],
            expected: &[usize],
        ) where
            I: IntoIterator,
            I::Item: SampleBorrow<X>,
        {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(701);
            buf.iter_mut().for_each(|slot| *slot = rng.sample(&distr));
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(
            [1i32, 1, 1, 1, 1, 1, 1, 1, 1],
            &mut buf,
            &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5],
        );
        test_samples(
            [0.7f32, 0.1, 0.1, 0.1],
            &mut buf,
            &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0],
        );
        test_samples(
            [1.0f64, 0.999, 0.998, 0.997],
            &mut buf,
            &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1],
        );
    }

    /// Two results built from equal weights compare equal.
    #[test]
    fn weighted_index_distributions_can_be_compared() {
        let left = WeightedIndex::new([1, 2]);
        let right = WeightedIndex::new([1, 2]);
        assert_eq!(left, right);
    }

    /// An integer total that does not fit the weight type is an error.
    #[test]
    fn overflow() {
        let result = WeightedIndex::new([2, usize::MAX]);
        assert_eq!(result, Err(Error::Overflow));
    }
}