rand/seq/
iterator.rs

1// Copyright 2018-2024 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
9//! `IteratorRandom`
10
11use super::coin_flipper::CoinFlipper;
12#[allow(unused)]
13use super::IndexedRandom;
14use crate::Rng;
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18/// Extension trait on iterators, providing random sampling methods.
19///
20/// This trait is implemented on all iterators `I` where `I: Iterator + Sized`
21/// and provides methods for
22/// choosing one or more elements. You must `use` this trait:
23///
24/// ```
25/// use rand::seq::IteratorRandom;
26///
27/// let faces = "😀😎😐😕😠😢";
28/// println!("I am {}!", faces.chars().choose(&mut rand::rng()).unwrap());
29/// ```
30/// Example output (non-deterministic):
31/// ```none
32/// I am 😀!
33/// ```
34pub trait IteratorRandom: Iterator + Sized {
35    /// Uniformly sample one element
36    ///
37    /// Assuming that the [`Iterator::size_hint`] is correct, this method
38    /// returns one uniformly-sampled random element of the slice, or `None`
39    /// only if the slice is empty. Incorrect bounds on the `size_hint` may
40    /// cause this method to incorrectly return `None` if fewer elements than
41    /// the advertised `lower` bound are present and may prevent sampling of
42    /// elements beyond an advertised `upper` bound (i.e. incorrect `size_hint`
43    /// is memory-safe, but may result in unexpected `None` result and
44    /// non-uniform distribution).
45    ///
46    /// With an accurate [`Iterator::size_hint`] and where [`Iterator::nth`] is
47    /// a constant-time operation, this method can offer `O(1)` performance.
48    /// Where no size hint is
49    /// available, complexity is `O(n)` where `n` is the iterator length.
50    /// Partial hints (where `lower > 0`) also improve performance.
51    ///
52    /// Note further that [`Iterator::size_hint`] may affect the number of RNG
53    /// samples used as well as the result (while remaining uniform sampling).
54    /// Consider instead using [`IteratorRandom::choose_stable`] to avoid
55    /// [`Iterator`] combinators which only change size hints from affecting the
56    /// results.
57    ///
58    /// # Example
59    ///
60    /// ```
61    /// use rand::seq::IteratorRandom;
62    ///
63    /// let words = "Mary had a little lamb".split(' ');
64    /// println!("{}", words.choose(&mut rand::rng()).unwrap());
65    /// ```
66    fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67    where
68        R: Rng + ?Sized,
69    {
70        let (mut lower, mut upper) = self.size_hint();
71        let mut result = None;
72
73        // Handling for this condition outside the loop allows the optimizer to eliminate the loop
74        // when the Iterator is an ExactSizeIterator. This has a large performance impact on e.g.
75        // seq_iter_choose_from_1000.
76        if upper == Some(lower) {
77            return match lower {
78                0 => None,
79                1 => self.next(),
80                _ => self.nth(rng.random_range(..lower)),
81            };
82        }
83
84        let mut coin_flipper = CoinFlipper::new(rng);
85        let mut consumed = 0;
86
87        // Continue until the iterator is exhausted
88        loop {
89            if lower > 1 {
90                let ix = coin_flipper.rng.random_range(..lower + consumed);
91                let skip = if ix < lower {
92                    result = self.nth(ix);
93                    lower - (ix + 1)
94                } else {
95                    lower
96                };
97                if upper == Some(lower) {
98                    return result;
99                }
100                consumed += lower;
101                if skip > 0 {
102                    self.nth(skip - 1);
103                }
104            } else {
105                let elem = self.next();
106                if elem.is_none() {
107                    return result;
108                }
109                consumed += 1;
110                if coin_flipper.random_ratio_one_over(consumed) {
111                    result = elem;
112                }
113            }
114
115            let hint = self.size_hint();
116            lower = hint.0;
117            upper = hint.1;
118        }
119    }
120
121    /// Uniformly sample one element (stable)
122    ///
123    /// This method is very similar to [`choose`] except that the result
124    /// only depends on the length of the iterator and the values produced by
125    /// `rng`. Notably for any iterator of a given length this will make the
126    /// same requests to `rng` and if the same sequence of values are produced
127    /// the same index will be selected from `self`. This may be useful if you
128    /// need consistent results no matter what type of iterator you are working
129    /// with. If you do not need this stability prefer [`choose`].
130    ///
131    /// Note that this method still uses [`Iterator::size_hint`] to skip
132    /// constructing elements where possible, however the selection and `rng`
133    /// calls are the same in the face of this optimization. If you want to
134    /// force every element to be created regardless call `.inspect(|e| ())`.
135    ///
136    /// [`choose`]: IteratorRandom::choose
137    //
138    // Clippy is wrong here: we need to iterate over all entries with the RNG to
139    // ensure that choosing is *stable*.
140    // "allow(unknown_lints)" can be removed when switching to at least
141    // rust-version 1.86.0, see:
142    // https://rust-lang.github.io/rust-clippy/master/index.html#double_ended_iterator_last
143    #[allow(unknown_lints)]
144    #[allow(clippy::double_ended_iterator_last)]
145    fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
146    where
147        R: Rng + ?Sized,
148    {
149        let mut consumed = 0;
150        let mut result = None;
151        let mut coin_flipper = CoinFlipper::new(rng);
152
153        loop {
154            // Currently the only way to skip elements is `nth()`. So we need to
155            // store what index to access next here.
156            // This should be replaced by `advance_by()` once it is stable:
157            // https://github.com/rust-lang/rust/issues/77404
158            let mut next = 0;
159
160            let (lower, _) = self.size_hint();
161            if lower >= 2 {
162                let highest_selected = (0..lower)
163                    .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
164                    .last();
165
166                consumed += lower;
167                next = lower;
168
169                if let Some(ix) = highest_selected {
170                    result = self.nth(ix);
171                    next -= ix + 1;
172                    debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
173                }
174            }
175
176            let elem = self.nth(next);
177            if elem.is_none() {
178                return result;
179            }
180
181            if coin_flipper.random_ratio_one_over(consumed + 1) {
182                result = elem;
183            }
184            consumed += 1;
185        }
186    }
187
188    /// Uniformly sample `amount` distinct elements into a buffer
189    ///
190    /// Collects values at random from the iterator into a supplied buffer
191    /// until that buffer is filled.
192    ///
193    /// Although the elements are selected randomly, the order of elements in
194    /// the buffer is neither stable nor fully random. If random ordering is
195    /// desired, shuffle the result.
196    ///
197    /// Returns the number of elements added to the buffer. This equals the length
198    /// of the buffer unless the iterator contains insufficient elements, in which
199    /// case this equals the number of elements available.
200    ///
201    /// Complexity is `O(n)` where `n` is the length of the iterator.
202    /// For slices, prefer [`IndexedRandom::choose_multiple`].
203    fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
204    where
205        R: Rng + ?Sized,
206    {
207        let amount = buf.len();
208        let mut len = 0;
209        while len < amount {
210            if let Some(elem) = self.next() {
211                buf[len] = elem;
212                len += 1;
213            } else {
214                // Iterator exhausted; stop early
215                return len;
216            }
217        }
218
219        // Continue, since the iterator was not exhausted
220        for (i, elem) in self.enumerate() {
221            let k = rng.random_range(..i + 1 + amount);
222            if let Some(slot) = buf.get_mut(k) {
223                *slot = elem;
224            }
225        }
226        len
227    }
228
229    /// Uniformly sample `amount` distinct elements into a [`Vec`]
230    ///
231    /// This is equivalent to `choose_multiple_fill` except for the result type.
232    ///
233    /// Although the elements are selected randomly, the order of elements in
234    /// the buffer is neither stable nor fully random. If random ordering is
235    /// desired, shuffle the result.
236    ///
237    /// The length of the returned vector equals `amount` unless the iterator
238    /// contains insufficient elements, in which case it equals the number of
239    /// elements available.
240    ///
241    /// Complexity is `O(n)` where `n` is the length of the iterator.
242    /// For slices, prefer [`IndexedRandom::choose_multiple`].
243    #[cfg(feature = "alloc")]
244    fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
245    where
246        R: Rng + ?Sized,
247    {
248        let mut reservoir = Vec::with_capacity(amount);
249        reservoir.extend(self.by_ref().take(amount));
250
251        // Continue unless the iterator was exhausted
252        //
253        // note: this prevents iterators that "restart" from causing problems.
254        // If the iterator stops once, then so do we.
255        if reservoir.len() == amount {
256            for (i, elem) in self.enumerate() {
257                let k = rng.random_range(..i + 1 + amount);
258                if let Some(slot) = reservoir.get_mut(k) {
259                    *slot = elem;
260                }
261            }
262        } else {
263            // Don't hang onto extra memory. There is a corner case where
264            // `amount` was much less than `self.len()`.
265            reservoir.shrink_to_fit();
266        }
267        reservoir
268    }
269}
270
271impl<I> IteratorRandom for I where I: Iterator + Sized {}
272
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec::Vec;

    /// Wrapper that forwards `next` only, so the wrapped iterator's own
    /// `size_hint` is not exposed (the default `Iterator::size_hint` is used).
    #[derive(Clone)]
    struct UnhintedIterator<I: Iterator + Clone> {
        iter: I,
    }
    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    /// Wrapper whose lower size hint counts down within chunks of at most
    /// `chunk_size` elements. The upper hint is the exact remaining length
    /// when `hint_total_size` is set and `None` otherwise.
    #[derive(Clone)]
    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        chunk_remaining: usize,
        chunk_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            // Start a new chunk (capped by what is left) once the current one
            // has been used up.
            if self.chunk_remaining == 0 {
                self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len());
            }
            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);

            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                self.chunk_remaining,
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// Wrapper whose lower size hint is the remaining length capped at
    /// `window_size`. The upper hint is the exact remaining length when
    /// `hint_total_size` is set and `None` otherwise.
    #[derive(Clone)]
    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        window_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                core::cmp::min(self.iter.len(), self.window_size),
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// `choose` should pick each of 9 elements about equally often, whatever
    /// size hints the iterator reports, and return `None` when empty.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Samples should follow Binomial(1000, 1/9)
                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
                // Note: have seen 153, which is unlikely but not impossible.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        assert_eq!((0..0).choose(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
    }

    /// `choose_stable` should pick each of 9 elements about equally often,
    /// whatever size hints the iterator reports, and return `None` when empty.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose_stable() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Samples should follow Binomial(1000, 1/9)
                // Octave: binopdf(x, 1000, 1/9) gives the prob of *count == x
                // Note: have seen 153, which is unlikely but not impossible.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        // Empty iterators must be checked through `choose_stable` itself, the
        // method under test here (not `choose`).
        assert_eq!((0..0).choose_stable(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
    }

    /// With an identically seeded RNG, `choose_stable` must produce the same
    /// selection counts for every iterator type of the same length.
    #[test]
    #[cfg_attr(miri, ignore)] // Miri is too slow
    fn test_iterator_choose_stable_stability() {
        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
            // A fresh RNG with the same seed for every iterator under test.
            let r = &mut crate::test::rng(109);
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            chosen
        }

        let reference = test_iter(0..9);
        assert_eq!(
            test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
            reference
        );

        #[cfg(feature = "alloc")]
        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            }),
            reference
        );
    }

    /// `choose_multiple` returns `amount` elements, or all of them (in
    /// iteration order) when `amount` exceeds the iterator length.
    #[test]
    #[cfg(feature = "alloc")]
    fn test_sample_iter() {
        let min_val = 1;
        // Exclusive upper bound of the sampled values.
        let max_val = 100;

        let mut r = crate::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = vals.iter().choose_multiple(&mut r, 5);
        let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // no randomization happens when amount >= len
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // Every sampled value must come from the half-open source range.
        assert!(small_sample
            .iter()
            .all(|e| (min_val..max_val).contains(*e)));
    }

    /// Pins the exact values produced by `choose` for a fixed seed.
    #[test]
    fn value_stability_choose() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(33));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(91)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(91)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(34)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(34)
        );
    }

    /// Pins the exact value produced by `choose_stable` for a fixed seed; it
    /// must be the same for every iterator type of the same length.
    #[test]
    fn value_stability_choose_stable() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose_stable(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(27));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
    }

    /// Pins the exact values produced by `choose_multiple_fill` (and, with
    /// `alloc`, `choose_multiple`) for a fixed seed.
    #[test]
    fn value_stability_choose_multiple() {
        fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
            let mut rng = crate::test::rng(412);
            let mut buf = [0u32; 8];
            assert_eq!(
                iter.clone().choose_multiple_fill(&mut rng, &mut buf),
                v.len()
            );
            assert_eq!(&buf[0..v.len()], v);

            #[cfg(feature = "alloc")]
            {
                // Same seed again: both methods must agree on the selection.
                let mut rng = crate::test::rng(412);
                assert_eq!(iter.choose_multiple(&mut rng, v.len()), v);
            }
        }

        do_test(0..4, &[0, 1, 2, 3]);
        do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
        do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
    }
}