// rand/seq/slice.rs

// Copyright 2018-2023 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! `IndexedRandom`, `IndexedMutRandom`, `SliceRandom`
10
11use super::increasing_uniform::IncreasingUniform;
12use super::index;
13#[cfg(feature = "alloc")]
14use crate::distr::uniform::{SampleBorrow, SampleUniform};
15#[cfg(feature = "alloc")]
16use crate::distr::weighted::{Error as WeightError, Weight};
17use crate::Rng;
18use core::ops::{Index, IndexMut};
19
20/// Extension trait on indexable lists, providing random sampling methods.
21///
22/// This trait is implemented on `[T]` slice types. Other types supporting
23/// [`std::ops::Index<usize>`] may implement this (only [`Self::len`] must be
24/// specified).
25pub trait IndexedRandom: Index<usize> {
26    /// The length
27    fn len(&self) -> usize;
28
29    /// True when the length is zero
30    #[inline]
31    fn is_empty(&self) -> bool {
32        self.len() == 0
33    }
34
35    /// Uniformly sample one element
36    ///
37    /// Returns a reference to one uniformly-sampled random element of
38    /// the slice, or `None` if the slice is empty.
39    ///
40    /// For slices, complexity is `O(1)`.
41    ///
42    /// # Example
43    ///
44    /// ```
45    /// use rand::seq::IndexedRandom;
46    ///
47    /// let choices = [1, 2, 4, 8, 16, 32];
48    /// let mut rng = rand::rng();
49    /// println!("{:?}", choices.choose(&mut rng));
50    /// assert_eq!(choices[..0].choose(&mut rng), None);
51    /// ```
52    fn choose<R>(&self, rng: &mut R) -> Option<&Self::Output>
53    where
54        R: Rng + ?Sized,
55    {
56        if self.is_empty() {
57            None
58        } else {
59            Some(&self[rng.random_range(..self.len())])
60        }
61    }
62
63    /// Uniformly sample `amount` distinct elements from self
64    ///
65    /// Chooses `amount` elements from the slice at random, without repetition,
66    /// and in random order. The returned iterator is appropriate both for
67    /// collection into a `Vec` and filling an existing buffer (see example).
68    ///
69    /// In case this API is not sufficiently flexible, use [`index::sample`].
70    ///
71    /// For slices, complexity is the same as [`index::sample`].
72    ///
73    /// # Example
74    /// ```
75    /// use rand::seq::IndexedRandom;
76    ///
77    /// let mut rng = &mut rand::rng();
78    /// let sample = "Hello, audience!".as_bytes();
79    ///
80    /// // collect the results into a vector:
81    /// let v: Vec<u8> = sample.choose_multiple(&mut rng, 3).cloned().collect();
82    ///
83    /// // store in a buffer:
84    /// let mut buf = [0u8; 5];
85    /// for (b, slot) in sample.choose_multiple(&mut rng, buf.len()).zip(buf.iter_mut()) {
86    ///     *slot = *b;
87    /// }
88    /// ```
89    #[cfg(feature = "alloc")]
90    fn choose_multiple<R>(&self, rng: &mut R, amount: usize) -> SliceChooseIter<Self, Self::Output>
91    where
92        Self::Output: Sized,
93        R: Rng + ?Sized,
94    {
95        let amount = core::cmp::min(amount, self.len());
96        SliceChooseIter {
97            slice: self,
98            _phantom: Default::default(),
99            indices: index::sample(rng, self.len(), amount).into_iter(),
100        }
101    }
102
103    /// Uniformly sample a fixed-size array of distinct elements from self
104    ///
105    /// Chooses `N` elements from the slice at random, without repetition,
106    /// and in random order.
107    ///
108    /// For slices, complexity is the same as [`index::sample_array`].
109    ///
110    /// # Example
111    /// ```
112    /// use rand::seq::IndexedRandom;
113    ///
114    /// let mut rng = &mut rand::rng();
115    /// let sample = "Hello, audience!".as_bytes();
116    ///
117    /// let a: [u8; 3] = sample.choose_multiple_array(&mut rng).unwrap();
118    /// ```
119    fn choose_multiple_array<R, const N: usize>(&self, rng: &mut R) -> Option<[Self::Output; N]>
120    where
121        Self::Output: Clone + Sized,
122        R: Rng + ?Sized,
123    {
124        let indices = index::sample_array(rng, self.len())?;
125        Some(indices.map(|index| self[index].clone()))
126    }
127
128    /// Biased sampling for one element
129    ///
130    /// Returns a reference to one element of the slice, sampled according
131    /// to the provided weights. Returns `None` only if the slice is empty.
132    ///
133    /// The specified function `weight` maps each item `x` to a relative
134    /// likelihood `weight(x)`. The probability of each item being selected is
135    /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
136    ///
137    /// For slices of length `n`, complexity is `O(n)`.
138    /// For more information about the underlying algorithm,
139    /// see the [`WeightedIndex`] distribution.
140    ///
141    /// See also [`choose_weighted_mut`].
142    ///
143    /// # Example
144    ///
145    /// ```
146    /// use rand::prelude::*;
147    ///
148    /// let choices = [('a', 2), ('b', 1), ('c', 1), ('d', 0)];
149    /// let mut rng = rand::rng();
150    /// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c',
151    /// // and 'd' will never be printed
152    /// println!("{:?}", choices.choose_weighted(&mut rng, |item| item.1).unwrap().0);
153    /// ```
154    /// [`choose`]: IndexedRandom::choose
155    /// [`choose_weighted_mut`]: IndexedMutRandom::choose_weighted_mut
156    /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex
157    #[cfg(feature = "alloc")]
158    fn choose_weighted<R, F, B, X>(
159        &self,
160        rng: &mut R,
161        weight: F,
162    ) -> Result<&Self::Output, WeightError>
163    where
164        R: Rng + ?Sized,
165        F: Fn(&Self::Output) -> B,
166        B: SampleBorrow<X>,
167        X: SampleUniform + Weight + PartialOrd<X>,
168    {
169        use crate::distr::{weighted::WeightedIndex, Distribution};
170        let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?;
171        Ok(&self[distr.sample(rng)])
172    }
173
174    /// Biased sampling of `amount` distinct elements
175    ///
176    /// Similar to [`choose_multiple`], but where the likelihood of each
177    /// element's inclusion in the output may be specified. Zero-weighted
178    /// elements are never returned; the result may therefore contain fewer
179    /// elements than `amount` even when `self.len() >= amount`. The elements
180    /// are returned in an arbitrary, unspecified order.
181    ///
182    /// The specified function `weight` maps each item `x` to a relative
183    /// likelihood `weight(x)`. The probability of each item being selected is
184    /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
185    ///
186    /// This implementation uses `O(length + amount)` space and `O(length)` time.
187    /// See [`index::sample_weighted`] for details.
188    ///
189    /// # Example
190    ///
191    /// ```
192    /// use rand::prelude::*;
193    ///
194    /// let choices = [('a', 2), ('b', 1), ('c', 1)];
195    /// let mut rng = rand::rng();
196    /// // First Draw * Second Draw = total odds
197    /// // -----------------------
198    /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'b']` in some order.
199    /// // (50% * 50%) + (25% * 67%) = 41.7% chance that the output is `['a', 'c']` in some order.
200    /// // (25% * 33%) + (25% * 33%) = 16.6% chance that the output is `['b', 'c']` in some order.
201    /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>());
202    /// ```
203    /// [`choose_multiple`]: IndexedRandom::choose_multiple
204    // Note: this is feature-gated on std due to usage of f64::powf.
205    // If necessary, we may use alloc+libm as an alternative (see PR #1089).
206    #[cfg(feature = "std")]
207    fn choose_multiple_weighted<R, F, X>(
208        &self,
209        rng: &mut R,
210        amount: usize,
211        weight: F,
212    ) -> Result<SliceChooseIter<Self, Self::Output>, WeightError>
213    where
214        Self::Output: Sized,
215        R: Rng + ?Sized,
216        F: Fn(&Self::Output) -> X,
217        X: Into<f64>,
218    {
219        let amount = core::cmp::min(amount, self.len());
220        Ok(SliceChooseIter {
221            slice: self,
222            _phantom: Default::default(),
223            indices: index::sample_weighted(
224                rng,
225                self.len(),
226                |idx| weight(&self[idx]).into(),
227                amount,
228            )?
229            .into_iter(),
230        })
231    }
232}
233
234/// Extension trait on indexable lists, providing random sampling methods.
235///
236/// This trait is implemented automatically for every type implementing
237/// [`IndexedRandom`] and [`std::ops::IndexMut<usize>`].
238pub trait IndexedMutRandom: IndexedRandom + IndexMut<usize> {
239    /// Uniformly sample one element (mut)
240    ///
241    /// Returns a mutable reference to one uniformly-sampled random element of
242    /// the slice, or `None` if the slice is empty.
243    ///
244    /// For slices, complexity is `O(1)`.
245    fn choose_mut<R>(&mut self, rng: &mut R) -> Option<&mut Self::Output>
246    where
247        R: Rng + ?Sized,
248    {
249        if self.is_empty() {
250            None
251        } else {
252            let len = self.len();
253            Some(&mut self[rng.random_range(..len)])
254        }
255    }
256
257    /// Biased sampling for one element (mut)
258    ///
259    /// Returns a mutable reference to one element of the slice, sampled according
260    /// to the provided weights. Returns `None` only if the slice is empty.
261    ///
262    /// The specified function `weight` maps each item `x` to a relative
263    /// likelihood `weight(x)`. The probability of each item being selected is
264    /// therefore `weight(x) / s`, where `s` is the sum of all `weight(x)`.
265    ///
266    /// For slices of length `n`, complexity is `O(n)`.
267    /// For more information about the underlying algorithm,
268    /// see the [`WeightedIndex`] distribution.
269    ///
270    /// See also [`choose_weighted`].
271    ///
272    /// [`choose_mut`]: IndexedMutRandom::choose_mut
273    /// [`choose_weighted`]: IndexedRandom::choose_weighted
274    /// [`WeightedIndex`]: crate::distr::weighted::WeightedIndex
275    #[cfg(feature = "alloc")]
276    fn choose_weighted_mut<R, F, B, X>(
277        &mut self,
278        rng: &mut R,
279        weight: F,
280    ) -> Result<&mut Self::Output, WeightError>
281    where
282        R: Rng + ?Sized,
283        F: Fn(&Self::Output) -> B,
284        B: SampleBorrow<X>,
285        X: SampleUniform + Weight + PartialOrd<X>,
286    {
287        use crate::distr::{weighted::WeightedIndex, Distribution};
288        let distr = WeightedIndex::new((0..self.len()).map(|idx| weight(&self[idx])))?;
289        let index = distr.sample(rng);
290        Ok(&mut self[index])
291    }
292}
293
294/// Extension trait on slices, providing shuffling methods.
295///
296/// This trait is implemented on all `[T]` slice types, providing several
297/// methods for choosing and shuffling elements. You must `use` this trait:
298///
299/// ```
300/// use rand::seq::SliceRandom;
301///
302/// let mut rng = rand::rng();
303/// let mut bytes = "Hello, random!".to_string().into_bytes();
304/// bytes.shuffle(&mut rng);
305/// let str = String::from_utf8(bytes).unwrap();
306/// println!("{}", str);
307/// ```
308/// Example output (non-deterministic):
309/// ```none
310/// l,nmroHado !le
311/// ```
312pub trait SliceRandom: IndexedMutRandom {
313    /// Shuffle a mutable slice in place.
314    ///
315    /// For slices of length `n`, complexity is `O(n)`.
316    /// The resulting permutation is picked uniformly from the set of all possible permutations.
317    ///
318    /// # Example
319    ///
320    /// ```
321    /// use rand::seq::SliceRandom;
322    ///
323    /// let mut rng = rand::rng();
324    /// let mut y = [1, 2, 3, 4, 5];
325    /// println!("Unshuffled: {:?}", y);
326    /// y.shuffle(&mut rng);
327    /// println!("Shuffled:   {:?}", y);
328    /// ```
329    fn shuffle<R>(&mut self, rng: &mut R)
330    where
331        R: Rng + ?Sized;
332
333    /// Shuffle a slice in place, but exit early.
334    ///
335    /// Returns two mutable slices from the source slice. The first contains
336    /// `amount` elements randomly permuted. The second has the remaining
337    /// elements that are not fully shuffled.
338    ///
339    /// This is an efficient method to select `amount` elements at random from
340    /// the slice, provided the slice may be mutated.
341    ///
342    /// If you only need to choose elements randomly and `amount > self.len()/2`
343    /// then you may improve performance by taking
344    /// `amount = self.len() - amount` and using only the second slice.
345    ///
346    /// If `amount` is greater than the number of elements in the slice, this
347    /// will perform a full shuffle.
348    ///
349    /// For slices, complexity is `O(m)` where `m = amount`.
350    fn partial_shuffle<R>(
351        &mut self,
352        rng: &mut R,
353        amount: usize,
354    ) -> (&mut [Self::Output], &mut [Self::Output])
355    where
356        Self::Output: Sized,
357        R: Rng + ?Sized;
358}
359
360impl<T> IndexedRandom for [T] {
361    fn len(&self) -> usize {
362        self.len()
363    }
364}
365
366impl<IR: IndexedRandom + IndexMut<usize> + ?Sized> IndexedMutRandom for IR {}
367
368impl<T> SliceRandom for [T] {
369    fn shuffle<R>(&mut self, rng: &mut R)
370    where
371        R: Rng + ?Sized,
372    {
373        if self.len() <= 1 {
374            // There is no need to shuffle an empty or single element slice
375            return;
376        }
377        self.partial_shuffle(rng, self.len());
378    }
379
380    fn partial_shuffle<R>(&mut self, rng: &mut R, amount: usize) -> (&mut [T], &mut [T])
381    where
382        R: Rng + ?Sized,
383    {
384        let m = self.len().saturating_sub(amount);
385
386        // The algorithm below is based on Durstenfeld's algorithm for the
387        // [Fisher–Yates shuffle](https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm)
388        // for an unbiased permutation.
389        // It ensures that the last `amount` elements of the slice
390        // are randomly selected from the whole slice.
391
392        // `IncreasingUniform::next_index()` is faster than `Rng::random_range`
393        // but only works for 32 bit integers
394        // So we must use the slow method if the slice is longer than that.
395        if self.len() < (u32::MAX as usize) {
396            let mut chooser = IncreasingUniform::new(rng, m as u32);
397            for i in m..self.len() {
398                let index = chooser.next_index();
399                self.swap(i, index);
400            }
401        } else {
402            for i in m..self.len() {
403                let index = rng.random_range(..i + 1);
404                self.swap(i, index);
405            }
406        }
407        let r = self.split_at_mut(m);
408        (r.1, r.0)
409    }
410}
411
/// An iterator over multiple slice elements.
///
/// This struct is created by [`IndexedRandom::choose_multiple`] and (with the
/// `std` feature) by `IndexedRandom::choose_multiple_weighted`. It yields a
/// reference to the element at each selected index, in selection order.
#[cfg(feature = "alloc")]
#[derive(Debug)]
pub struct SliceChooseIter<'a, S: ?Sized + 'a, T: 'a> {
    /// The list the elements are borrowed from.
    slice: &'a S,
    /// Records the element type `T` without storing a `T`.
    _phantom: core::marker::PhantomData<T>,
    /// Indices into `slice` that have not been yielded yet.
    indices: index::IndexVecIntoIter,
}
423
#[cfg(feature = "alloc")]
impl<'a, S, T> Iterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    type Item = &'a T;

    /// Yields a reference to the element at the next selected index.
    fn next(&mut self) -> Option<Self::Item> {
        // TODO: investigate using SliceIndex::get_unchecked when stable
        let slice: &'a S = self.slice;
        let i = self.indices.next()?;
        Some(&slice[i])
    }

    /// Exact: exactly one item remains per unconsumed index.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.indices.len();
        (remaining, Some(remaining))
    }
}
437
#[cfg(feature = "alloc")]
impl<'a, S, T> ExactSizeIterator for SliceChooseIter<'a, S, T>
where
    S: Index<usize, Output = T> + ?Sized + 'a,
    T: 'a,
{
    /// Number of elements still to be yielded (one per remaining index).
    fn len(&self) -> usize {
        self.indices.len()
    }
}
446
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(feature = "alloc")]
    use alloc::vec::Vec;

    #[test]
    fn test_slice_choose() {
        let mut r = crate::test::rng(107);
        let chars = [
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
        ];
        let mut chosen = [0i32; 14];
        // The below all use a binomial distribution with n=1000, p=1/14.
        // binocdf(40, 1000, 1/14) ~= 2e-5; 1-binocdf(106, ..) ~= 2e-5
        for _ in 0..1000 {
            let picked = *chars.choose(&mut r).unwrap();
            chosen[(picked as usize) - ('a' as usize)] += 1;
        }
        for count in chosen.iter() {
            assert!(40 < *count && *count < 106);
        }

        chosen.iter_mut().for_each(|x| *x = 0);
        for _ in 0..1000 {
            *chosen.choose_mut(&mut r).unwrap() += 1;
        }
        for count in chosen.iter() {
            assert!(40 < *count && *count < 106);
        }

        let mut v: [isize; 0] = [];
        assert_eq!(v.choose(&mut r), None);
        assert_eq!(v.choose_mut(&mut r), None);
    }

    #[test]
    fn value_stability_slice() {
        let mut r = crate::test::rng(413);
        let chars = [
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n',
        ];
        let mut nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];

        assert_eq!(chars.choose(&mut r), Some(&'l'));
        assert_eq!(nums.choose_mut(&mut r), Some(&mut 3));

        assert_eq!(
            &chars.choose_multiple_array(&mut r),
            &Some(['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k'])
        );

        #[cfg(feature = "alloc")]
        assert_eq!(
            &chars
                .choose_multiple(&mut r, 8)
                .cloned()
                .collect::<Vec<char>>(),
            &['h', 'm', 'd', 'b', 'c', 'e', 'n', 'f']
        );

        #[cfg(feature = "alloc")]
        assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'i'));
        #[cfg(feature = "alloc")]
        assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 2));

        let mut r = crate::test::rng(414);
        nums.shuffle(&mut r);
        assert_eq!(nums, [5, 11, 0, 8, 7, 12, 6, 4, 9, 3, 1, 2, 10]);
        nums = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12];
        let res = nums.partial_shuffle(&mut r, 6);
        assert_eq!(res.0, &mut [7, 12, 6, 8, 1, 9]);
        assert_eq!(res.1, &mut [0, 11, 2, 3, 4, 5, 10]);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_shuffle() {
        let mut r = crate::test::rng(108);
        let empty: &mut [isize] = &mut [];
        empty.shuffle(&mut r);
        let mut one = [1];
        one.shuffle(&mut r);
        let b: &[_] = &[1];
        assert_eq!(one, b);

        let mut two = [1, 2];
        two.shuffle(&mut r);
        assert!(two == [1, 2] || two == [2, 1]);

        let mut counts = [0i32; 24];
        for _ in 0..10000 {
            let mut arr: [usize; 4] = [0, 1, 2, 3];
            arr.shuffle(&mut r);
            // Decode the permutation into a unique index in 0..24 (Lehmer code).
            let mut permutation = 0usize;
            let mut pos_value = counts.len();
            for i in 0..4 {
                pos_value /= 4 - i;
                let pos = arr.iter().position(|&x| x == i).unwrap();
                assert!(pos < (4 - i));
                permutation += pos * pos_value;
                // Move the value at `pos` to the end, shifting later elements left.
                arr[pos..].rotate_left(1);
                assert_eq!(arr[3], i);
            }
            for (i, &a) in arr.iter().enumerate() {
                assert_eq!(a, i);
            }
            counts[permutation] += 1;
        }
        for count in counts.iter() {
            // Binomial(10000, 1/24) with average 416.667
            // Octave: binocdf(n, 10000, 1/24)
            // 99.9% chance samples lie within this range:
            assert!(352 <= *count && *count <= 483, "count: {}", count);
        }
    }

    #[test]
    fn test_partial_shuffle() {
        let mut r = crate::test::rng(118);

        let mut empty: [u32; 0] = [];
        let res = empty.partial_shuffle(&mut r, 10);
        assert_eq!((res.0.len(), res.1.len()), (0, 0));

        let mut v = [1, 2, 3, 4, 5];
        let res = v.partial_shuffle(&mut r, 2);
        assert_eq!((res.0.len(), res.1.len()), (2, 3));
        assert!(res.0[0] != res.0[1]);
        // First elements are only modified if selected, so at least one isn't modified:
        assert!(res.1[0] == 1 || res.1[1] == 2 || res.1[2] == 3);
    }

    #[test]
    #[cfg(feature = "alloc")]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_weighted() {
        let mut r = crate::test::rng(406);
        const N_REPS: u32 = 3000;
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        let verify = |result: [i32; 14]| {
            for (i, count) in result.iter().enumerate() {
                let exp = (weights[i] * N_REPS) as f32 / total_weight;
                let mut err = (*count as f32 - exp).abs();
                if err != 0.0 {
                    err /= exp;
                }
                assert!(err <= 0.25);
            }
        };

        // choose_weighted
        fn get_weight<T>(item: &(u32, T)) -> u32 {
            item.0
        }
        let mut chosen = [0i32; 14];
        let mut items = [(0u32, 0usize); 14]; // (weight, index)
        for (i, item) in items.iter_mut().enumerate() {
            *item = (weights[i], i);
        }
        for _ in 0..N_REPS {
            let item = items.choose_weighted(&mut r, get_weight).unwrap();
            chosen[item.1] += 1;
        }
        verify(chosen);

        // choose_weighted_mut
        let mut items = [(0u32, 0i32); 14]; // (weight, count)
        for (i, item) in items.iter_mut().enumerate() {
            *item = (weights[i], 0);
        }
        for _ in 0..N_REPS {
            items.choose_weighted_mut(&mut r, get_weight).unwrap().1 += 1;
        }
        for (ch, item) in chosen.iter_mut().zip(items.iter()) {
            *ch = item.1;
        }
        verify(chosen);

        // Check error cases
        let empty_slice = &mut [10][0..0];
        assert_eq!(
            empty_slice.choose_weighted(&mut r, |_| 1),
            Err(WeightError::InvalidInput)
        );
        assert_eq!(
            empty_slice.choose_weighted_mut(&mut r, |_| 1),
            Err(WeightError::InvalidInput)
        );
        assert_eq!(
            ['x'].choose_weighted_mut(&mut r, |_| 0),
            Err(WeightError::InsufficientNonZero)
        );
        assert_eq!(
            [0, -1].choose_weighted_mut(&mut r, |x| *x),
            Err(WeightError::InvalidWeight)
        );
        assert_eq!(
            [-1, 0].choose_weighted_mut(&mut r, |x| *x),
            Err(WeightError::InvalidWeight)
        );
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_multiple_weighted_edge_cases() {
        let mut rng = crate::test::rng(413);

        // Case 1: One of the weights is 0
        let choices = [('a', 2), ('b', 1), ('c', 0)];
        for _ in 0..100 {
            let result = choices
                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
                .unwrap()
                .collect::<Vec<_>>();

            assert_eq!(result.len(), 2);
            assert!(!result.iter().any(|val| val.0 == 'c'));
        }

        // Case 2: All of the weights are 0
        let choices = [('a', 0), ('b', 0), ('c', 0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap().len(), 0);

        // Case 3: Negative weights
        let choices = [('a', -1), ('b', 1), ('c', 1)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap_err(), WeightError::InvalidWeight);

        // Case 4: Empty list
        let choices = [];
        let r = choices.choose_multiple_weighted(&mut rng, 0, |_: &()| 0);
        assert_eq!(r.unwrap().count(), 0);

        // Case 5: NaN weights
        let choices = [('a', f64::NAN), ('b', 1.0), ('c', 1.0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap_err(), WeightError::InvalidWeight);

        // Case 6: +infinity weights
        let choices = [('a', f64::INFINITY), ('b', 1.0), ('c', 1.0)];
        for _ in 0..100 {
            let result = choices
                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
                .unwrap()
                .collect::<Vec<_>>();
            assert_eq!(result.len(), 2);
            assert!(result.iter().any(|val| val.0 == 'a'));
        }

        // Case 7: -infinity weights
        let choices = [('a', f64::NEG_INFINITY), ('b', 1.0), ('c', 1.0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert_eq!(r.unwrap_err(), WeightError::InvalidWeight);

        // Case 8: -0 weights
        let choices = [('a', -0.0), ('b', 1.0), ('c', 1.0)];
        let r = choices.choose_multiple_weighted(&mut rng, 2, |item| item.1);
        assert!(r.is_ok());
    }

    #[test]
    #[cfg(feature = "std")]
    fn test_multiple_weighted_distributions() {
        // The theoretical probabilities of the different outcomes are:
        // AB: 0.5   * 0.667 = 0.3333
        // AC: 0.5   * 0.333 = 0.1667
        // BA: 0.333 * 0.75  = 0.25
        // BC: 0.333 * 0.25  = 0.0833
        // CA: 0.167 * 0.6   = 0.1
        // CB: 0.167 * 0.4   = 0.0667
        let choices = [('a', 3), ('b', 2), ('c', 1)];
        let mut rng = crate::test::rng(414);

        let mut results = [0i32; 3];
        let expected_results = [5833, 2667, 1500];
        for _ in 0..10000 {
            let result = choices
                .choose_multiple_weighted(&mut rng, 2, |item| item.1)
                .unwrap()
                .collect::<Vec<_>>();

            assert_eq!(result.len(), 2);

            match (result[0].0, result[1].0) {
                ('a', 'b') | ('b', 'a') => {
                    results[0] += 1;
                }
                ('a', 'c') | ('c', 'a') => {
                    results[1] += 1;
                }
                ('b', 'c') | ('c', 'b') => {
                    results[2] += 1;
                }
                (_, _) => panic!("unexpected result"),
            }
        }

        let mut diffs = results
            .iter()
            .zip(&expected_results)
            .map(|(a, b)| (a - b).abs());
        assert!(!diffs.any(|deviation| deviation > 100));
    }
}