rand/seq/
iterator.rs

1// Copyright 2018-2024 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! `IteratorRandom`
10
11use super::coin_flipper::CoinFlipper;
12#[allow(unused)]
13use super::IndexedRandom;
14use crate::Rng;
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18/// Extension trait on iterators, providing random sampling methods.
19///
20/// This trait is implemented on all iterators `I` where `I: Iterator + Sized`
21/// and provides methods for
22/// choosing one or more elements. You must `use` this trait:
23///
24/// ```
25/// use rand::seq::IteratorRandom;
26///
27/// let faces = "😀😎😐😕😠😢";
28/// println!("I am {}!", faces.chars().choose(&mut rand::rng()).unwrap());
29/// ```
30/// Example output (non-deterministic):
31/// ```none
32/// I am 😀!
33/// ```
34pub trait IteratorRandom: Iterator + Sized {
35    /// Uniformly sample one element
36    ///
37    /// Assuming that the [`Iterator::size_hint`] is correct, this method
38    /// returns one uniformly-sampled random element of the slice, or `None`
39    /// only if the slice is empty. Incorrect bounds on the `size_hint` may
40    /// cause this method to incorrectly return `None` if fewer elements than
41    /// the advertised `lower` bound are present and may prevent sampling of
42    /// elements beyond an advertised `upper` bound (i.e. incorrect `size_hint`
43    /// is memory-safe, but may result in unexpected `None` result and
44    /// non-uniform distribution).
45    ///
46    /// With an accurate [`Iterator::size_hint`] and where [`Iterator::nth`] is
47    /// a constant-time operation, this method can offer `O(1)` performance.
48    /// Where no size hint is
49    /// available, complexity is `O(n)` where `n` is the iterator length.
50    /// Partial hints (where `lower > 0`) also improve performance.
51    ///
52    /// Note further that [`Iterator::size_hint`] may affect the number of RNG
53    /// samples used as well as the result (while remaining uniform sampling).
54    /// Consider instead using [`IteratorRandom::choose_stable`] to avoid
55    /// [`Iterator`] combinators which only change size hints from affecting the
56    /// results.
57    ///
58    /// # Example
59    ///
60    /// ```
61    /// use rand::seq::IteratorRandom;
62    ///
63    /// let words = "Mary had a little lamb".split(' ');
64    /// println!("{}", words.choose(&mut rand::rng()).unwrap());
65    /// ```
66    fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67    where
68        R: Rng + ?Sized,
69    {
70        let (mut lower, mut upper) = self.size_hint();
71        let mut result = None;
72
73        // Handling for this condition outside the loop allows the optimizer to eliminate the loop
74        // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g.
75        // seq_iter_choose_from_1000.
76        if upper == Some(lower) {
77            return match lower {
78                0 => None,
79                1 => self.next(),
80                _ => self.nth(rng.random_range(..lower)),
81            };
82        }
83
84        let mut coin_flipper = CoinFlipper::new(rng);
85        let mut consumed = 0;
86
87        // Continue until the iterator is exhausted
88        loop {
89            if lower > 1 {
90                let ix = coin_flipper.rng.random_range(..lower + consumed);
91                let skip = if ix < lower {
92                    result = self.nth(ix);
93                    lower - (ix + 1)
94                } else {
95                    lower
96                };
97                if upper == Some(lower) {
98                    return result;
99                }
100                consumed += lower;
101                if skip > 0 {
102                    self.nth(skip - 1);
103                }
104            } else {
105                let elem = self.next();
106                if elem.is_none() {
107                    return result;
108                }
109                consumed += 1;
110                if coin_flipper.random_ratio_one_over(consumed) {
111                    result = elem;
112                }
113            }
114
115            let hint = self.size_hint();
116            lower = hint.0;
117            upper = hint.1;
118        }
119    }
120
121    /// Uniformly sample one element (stable)
122    ///
123    /// This method is very similar to [`choose`] except that the result
124    /// only depends on the length of the iterator and the values produced by
125    /// `rng`. Notably for any iterator of a given length this will make the
126    /// same requests to `rng` and if the same sequence of values are produced
127    /// the same index will be selected from `self`. This may be useful if you
128    /// need consistent results no matter what type of iterator you are working
129    /// with. If you do not need this stability prefer [`choose`].
130    ///
131    /// Note that this method still uses [`Iterator::size_hint`] to skip
132    /// constructing elements where possible, however the selection and `rng`
133    /// calls are the same in the face of this optimization. If you want to
134    /// force every element to be created regardless call `.inspect(|e| ())`.
135    ///
136    /// [`choose`]: IteratorRandom::choose
137    //
138    // Clippy is wrong here: we need to iterate over all entries with the RNG to
139    // ensure that choosing is *stable*.
140    #[allow(clippy::double_ended_iterator_last)]
141    fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
142    where
143        R: Rng + ?Sized,
144    {
145        let mut consumed = 0;
146        let mut result = None;
147        let mut coin_flipper = CoinFlipper::new(rng);
148
149        loop {
150            // Currently the only way to skip elements is `nth()`. So we need to
151            // store what index to access next here.
152            // This should be replaced by `advance_by()` once it is stable:
153            // https://github.com/rust-lang/rust/issues/77404
154            let mut next = 0;
155
156            let (lower, _) = self.size_hint();
157            if lower >= 2 {
158                let highest_selected = (0..lower)
159                    .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
160                    .last();
161
162                consumed += lower;
163                next = lower;
164
165                if let Some(ix) = highest_selected {
166                    result = self.nth(ix);
167                    next -= ix + 1;
168                    debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
169                }
170            }
171
172            let elem = self.nth(next);
173            if elem.is_none() {
174                return result;
175            }
176
177            if coin_flipper.random_ratio_one_over(consumed + 1) {
178                result = elem;
179            }
180            consumed += 1;
181        }
182    }
183
184    /// Uniformly sample `amount` distinct elements into a buffer
185    ///
186    /// Collects values at random from the iterator into a supplied buffer
187    /// until that buffer is filled.
188    ///
189    /// Although the elements are selected randomly, the order of elements in
190    /// the buffer is neither stable nor fully random. If random ordering is
191    /// desired, shuffle the result.
192    ///
193    /// Returns the number of elements added to the buffer. This equals the length
194    /// of the buffer unless the iterator contains insufficient elements, in which
195    /// case this equals the number of elements available.
196    ///
197    /// Complexity is `O(n)` where `n` is the length of the iterator.
198    /// For slices, prefer [`IndexedRandom::choose_multiple`].
199    fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
200    where
201        R: Rng + ?Sized,
202    {
203        let amount = buf.len();
204        let mut len = 0;
205        while len < amount {
206            if let Some(elem) = self.next() {
207                buf[len] = elem;
208                len += 1;
209            } else {
210                // Iterator exhausted; stop early
211                return len;
212            }
213        }
214
215        // Continue, since the iterator was not exhausted
216        for (i, elem) in self.enumerate() {
217            let k = rng.random_range(..i + 1 + amount);
218            if let Some(slot) = buf.get_mut(k) {
219                *slot = elem;
220            }
221        }
222        len
223    }
224
225    /// Uniformly sample `amount` distinct elements into a [`Vec`]
226    ///
227    /// This is equivalent to `choose_multiple_fill` except for the result type.
228    ///
229    /// Although the elements are selected randomly, the order of elements in
230    /// the buffer is neither stable nor fully random. If random ordering is
231    /// desired, shuffle the result.
232    ///
233    /// The length of the returned vector equals `amount` unless the iterator
234    /// contains insufficient elements, in which case it equals the number of
235    /// elements available.
236    ///
237    /// Complexity is `O(n)` where `n` is the length of the iterator.
238    /// For slices, prefer [`IndexedRandom::choose_multiple`].
239    #[cfg(feature = "alloc")]
240    fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
241    where
242        R: Rng + ?Sized,
243    {
244        let mut reservoir = Vec::with_capacity(amount);
245        reservoir.extend(self.by_ref().take(amount));
246
247        // Continue unless the iterator was exhausted
248        //
249        // note: this prevents iterators that "restart" from causing problems.
250        // If the iterator stops once, then so do we.
251        if reservoir.len() == amount {
252            for (i, elem) in self.enumerate() {
253                let k = rng.random_range(..i + 1 + amount);
254                if let Some(slot) = reservoir.get_mut(k) {
255                    *slot = elem;
256                }
257            }
258        } else {
259            // Don't hang onto extra memory. There is a corner case where
260            // `amount` was much less than `self.len()`.
261            reservoir.shrink_to_fit();
262        }
263        reservoir
264    }
265}
266
267impl<I> IteratorRandom for I where I: Iterator + Sized {}
268
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec::Vec;

    /// Forwards `next` only, so the wrapped iterator's size hint is hidden
    /// behind the default `Iterator::size_hint`.
    #[derive(Clone)]
    struct UnhintedIterator<I: Iterator + Clone> {
        iter: I,
    }
    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    /// Reports a lower bound that counts down within chunks of at most
    /// `chunk_size` elements; the upper bound is exact when
    /// `hint_total_size` is set and absent otherwise.
    #[derive(Clone)]
    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        chunk_remaining: usize,
        chunk_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            // Start a new chunk (capped by what is left) once the current
            // one has been used up.
            if self.chunk_remaining == 0 {
                self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len());
            }
            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);

            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                self.chunk_remaining,
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// Reports a lower bound of at most `window_size` remaining elements;
    /// the upper bound is exact when `hint_total_size` is set and absent
    /// otherwise.
    #[derive(Clone)]
    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        window_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                core::cmp::min(self.iter.len(), self.window_size),
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// Builds a [`ChunkHintedIterator`] positioned at the start of a full chunk.
    fn chunk_hinted<I>(iter: I, chunk_size: usize, hint_total_size: bool) -> ChunkHintedIterator<I>
    where
        I: ExactSizeIterator + Clone,
    {
        ChunkHintedIterator {
            iter,
            chunk_remaining: chunk_size,
            chunk_size,
            hint_total_size,
        }
    }

    /// Builds a [`WindowHintedIterator`].
    fn window_hinted<I>(
        iter: I,
        window_size: usize,
        hint_total_size: bool,
    ) -> WindowHintedIterator<I>
    where
        I: ExactSizeIterator + Clone,
    {
        WindowHintedIterator {
            iter,
            window_size,
            hint_total_size,
        }
    }

    // `choose` must sample uniformly whatever the shape of the size hint.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Samples should follow Binomial(1000, 1/9)
                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
                // Note: have seen 153, which is unlikely but not impossible.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(r, chunk_hinted(0..9, 4, false));
        test_iter(r, chunk_hinted(0..9, 4, true));
        test_iter(r, window_hinted(0..9, 2, false));
        test_iter(r, window_hinted(0..9, 2, true));

        assert_eq!((0..0).choose(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
    }

    // `choose_stable` must sample uniformly whatever the shape of the size hint.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose_stable() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Samples should follow Binomial(1000, 1/9)
                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
                // Note: have seen 153, which is unlikely but not impossible.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(r, chunk_hinted(0..9, 4, false));
        test_iter(r, chunk_hinted(0..9, 4, true));
        test_iter(r, window_hinted(0..9, 2, false));
        test_iter(r, window_hinted(0..9, 2, true));

        // Empty input must yield `None` from the method under test.
        assert_eq!((0..0).choose_stable(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
    }

    // With identical RNG state, `choose_stable` must pick the same indices
    // for every iterator of the same length, regardless of its size hint.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose_stable_stability() {
        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
            let r = &mut crate::test::rng(109);
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            chosen
        }

        let reference = test_iter(0..9);
        assert_eq!(
            test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
            reference
        );

        #[cfg(feature = "alloc")]
        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
        assert_eq!(test_iter(chunk_hinted(0..9, 4, false)), reference);
        assert_eq!(test_iter(chunk_hinted(0..9, 4, true)), reference);
        assert_eq!(test_iter(window_hinted(0..9, 2, false)), reference);
        assert_eq!(test_iter(window_hinted(0..9, 2, true)), reference);
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = crate::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = vals.iter().choose_multiple(&mut r, 5);
        let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // no randomization happens when amount >= len
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` covers the half-open range, so `max_val` itself is excluded.
        assert!(small_sample
            .iter()
            .all(|e| (min_val..max_val).contains(*e)));
    }

    // Pins the exact values produced by `choose` for a fixed seed.
    #[test]
    fn value_stability_choose() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(33));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(choose(chunk_hinted(0..100, 32, false)), Some(91));
        assert_eq!(choose(chunk_hinted(0..100, 32, true)), Some(91));
        assert_eq!(choose(window_hinted(0..100, 32, false)), Some(34));
        assert_eq!(choose(window_hinted(0..100, 32, true)), Some(34));
    }

    // Pins the exact value produced by `choose_stable` for a fixed seed; it
    // must not depend on the size hint.
    #[test]
    fn value_stability_choose_stable() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose_stable(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(27));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(choose(chunk_hinted(0..100, 32, false)), Some(27));
        assert_eq!(choose(chunk_hinted(0..100, 32, true)), Some(27));
        assert_eq!(choose(window_hinted(0..100, 32, false)), Some(27));
        assert_eq!(choose(window_hinted(0..100, 32, true)), Some(27));
    }

    // Pins the exact values produced by `choose_multiple_fill` and
    // `choose_multiple` for a fixed seed; both must agree.
    #[test]
    fn value_stability_choose_multiple() {
        fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
            let mut rng = crate::test::rng(412);
            let mut buf = [0u32; 8];
            assert_eq!(
                iter.clone().choose_multiple_fill(&mut rng, &mut buf),
                v.len()
            );
            assert_eq!(&buf[0..v.len()], v);

            #[cfg(feature = "alloc")]
            {
                let mut rng = crate::test::rng(412);
                assert_eq!(iter.choose_multiple(&mut rng, v.len()), v);
            }
        }

        do_test(0..4, &[0, 1, 2, 3]);
        do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
        do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
    }
}