1use super::coin_flipper::CoinFlipper;
12#[allow(unused)]
13use super::IndexedRandom;
14use crate::Rng;
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18pub trait IteratorRandom: Iterator + Sized {
35 fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67 where
68 R: Rng + ?Sized,
69 {
70 let (mut lower, mut upper) = self.size_hint();
71 let mut result = None;
72
73 if upper == Some(lower) {
77 return match lower {
78 0 => None,
79 1 => self.next(),
80 _ => self.nth(rng.random_range(..lower)),
81 };
82 }
83
84 let mut coin_flipper = CoinFlipper::new(rng);
85 let mut consumed = 0;
86
87 loop {
89 if lower > 1 {
90 let ix = coin_flipper.rng.random_range(..lower + consumed);
91 let skip = if ix < lower {
92 result = self.nth(ix);
93 lower - (ix + 1)
94 } else {
95 lower
96 };
97 if upper == Some(lower) {
98 return result;
99 }
100 consumed += lower;
101 if skip > 0 {
102 self.nth(skip - 1);
103 }
104 } else {
105 let elem = self.next();
106 if elem.is_none() {
107 return result;
108 }
109 consumed += 1;
110 if coin_flipper.random_ratio_one_over(consumed) {
111 result = elem;
112 }
113 }
114
115 let hint = self.size_hint();
116 lower = hint.0;
117 upper = hint.1;
118 }
119 }
120
121 #[allow(unknown_lints)]
144 #[allow(clippy::double_ended_iterator_last)]
145 fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
146 where
147 R: Rng + ?Sized,
148 {
149 let mut consumed = 0;
150 let mut result = None;
151 let mut coin_flipper = CoinFlipper::new(rng);
152
153 loop {
154 let mut next = 0;
159
160 let (lower, _) = self.size_hint();
161 if lower >= 2 {
162 let highest_selected = (0..lower)
163 .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
164 .last();
165
166 consumed += lower;
167 next = lower;
168
169 if let Some(ix) = highest_selected {
170 result = self.nth(ix);
171 next -= ix + 1;
172 debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
173 }
174 }
175
176 let elem = self.nth(next);
177 if elem.is_none() {
178 return result;
179 }
180
181 if coin_flipper.random_ratio_one_over(consumed + 1) {
182 result = elem;
183 }
184 consumed += 1;
185 }
186 }
187
188 fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
204 where
205 R: Rng + ?Sized,
206 {
207 let amount = buf.len();
208 let mut len = 0;
209 while len < amount {
210 if let Some(elem) = self.next() {
211 buf[len] = elem;
212 len += 1;
213 } else {
214 return len;
216 }
217 }
218
219 for (i, elem) in self.enumerate() {
221 let k = rng.random_range(..i + 1 + amount);
222 if let Some(slot) = buf.get_mut(k) {
223 *slot = elem;
224 }
225 }
226 len
227 }
228
/// Collect up to `amount` randomly chosen elements from the iterator into a
/// `Vec`.
///
/// The first `amount` elements seed the reservoir in order. If the iterator
/// ends before `amount` elements were seen, those elements are returned as-is
/// (no random numbers are drawn). Otherwise each later element (index `i`
/// after the initial fill) draws a slot number in `..i + 1 + amount` and
/// overwrites that slot when it falls inside the reservoir.
///
/// The initial allocation is capped by the iterator's upper `size_hint`, so
/// asking for far more elements than the iterator can yield does not reserve
/// (or fail to reserve) memory that can never be used.
#[cfg(feature = "alloc")]
fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
where
    R: Rng + ?Sized,
{
    // Never reserve more than the iterator says it can produce. An
    // understated upper bound is harmless: `extend` grows the vector.
    let capacity = match self.size_hint().1 {
        Some(upper) => amount.min(upper),
        None => amount,
    };
    let mut reservoir = Vec::with_capacity(capacity);
    reservoir.extend(self.by_ref().take(amount));

    if reservoir.len() == amount {
        // Reservoir is full: later elements may displace earlier ones.
        for (i, elem) in self.enumerate() {
            let k = rng.random_range(..i + 1 + amount);
            if let Some(slot) = reservoir.get_mut(k) {
                *slot = elem;
            }
        }
    } else {
        // The iterator was shorter than `amount`; it is not polled again.
        // Release any capacity that was reserved but not needed.
        reservoir.shrink_to_fit();
    }
    reservoir
}
269}
270
271impl<I> IteratorRandom for I where I: Iterator + Sized {}
272
273#[cfg(test)]
274mod test {
275 use super::*;
276 #[cfg(all(feature = "alloc", not(feature = "std")))]
277 use alloc::vec::Vec;
278
/// Test helper: wraps an iterator and forwards `next` only, so `size_hint`
/// falls back to the `Iterator` trait's default implementation.
#[derive(Clone)]
struct UnhintedIterator<I>
where
    I: Iterator + Clone,
{
    iter: I,
}

impl<I> Iterator for UnhintedIterator<I>
where
    I: Iterator + Clone,
{
    type Item = <I as Iterator>::Item;

    fn next(&mut self) -> Option<I::Item> {
        let Self { iter } = self;
        iter.next()
    }
}
290
/// Test helper: yields the wrapped iterator's elements while reporting a
/// lower `size_hint` that counts down within fixed-size chunks.
///
/// The lower bound is `chunk_remaining`; it drops to zero at the end of each
/// chunk and is refilled (to at most `chunk_size`, limited by what is left)
/// on the following call to `next`. The upper bound is the exact remaining
/// length when `hint_total_size` is set and `None` otherwise.
#[derive(Clone)]
struct ChunkHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    iter: I,
    chunk_remaining: usize,
    chunk_size: usize,
    hint_total_size: bool,
}

impl<I> Iterator for ChunkHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        // Start a new chunk when the previous one has been used up.
        let available = match self.chunk_remaining {
            0 => self.chunk_size.min(self.iter.len()),
            left => left,
        };
        // Saturating: `available` is zero once the inner iterator is empty.
        self.chunk_remaining = available.saturating_sub(1);

        self.iter.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let upper = self.hint_total_size.then(|| self.iter.len());
        (self.chunk_remaining, upper)
    }
}
321
/// Test helper: yields the wrapped iterator's elements while reporting a
/// lower `size_hint` capped at `window_size`.
///
/// The lower bound is the smaller of the remaining length and `window_size`.
/// The upper bound is the exact remaining length when `hint_total_size` is
/// set and `None` otherwise.
#[derive(Clone)]
struct WindowHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    iter: I,
    window_size: usize,
    hint_total_size: bool,
}

impl<I> Iterator for WindowHintedIterator<I>
where
    I: ExactSizeIterator + Iterator + Clone,
{
    type Item = I::Item;

    fn next(&mut self) -> Option<Self::Item> {
        let Self { iter, .. } = self;
        iter.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.iter.len();
        let upper = self.hint_total_size.then(|| remaining);
        (remaining.min(self.window_size), upper)
    }
}
346
/// Statistical smoke test for `choose`: across iterators with different
/// `size_hint` behaviour, 1000 picks from nine elements must land in every
/// bucket strictly more than 72 and strictly fewer than 154 times. An empty
/// iterator must yield `None`.
///
/// All runs share one generator seeded with 109, so the order of the calls
/// below is part of the test.
#[test]
#[cfg_attr(miri, ignore)]
fn test_iterator_choose() {
    /// Draw 1000 picks from clones of `iter` and check each bucket's count.
    fn test_iter<R, Iter>(r: &mut R, iter: Iter)
    where
        R: Rng + ?Sized,
        Iter: Iterator<Item = usize> + Clone,
    {
        let mut chosen = [0i32; 9];
        for _ in 0..1000 {
            let picked = iter.clone().choose(r).unwrap();
            chosen[picked] += 1;
        }
        for count in &chosen {
            // Loose band around the expected 1000/9 (about 111) per bucket.
            assert!(
                (73..154).contains(count),
                "count not close to 1000/9: {}",
                count
            );
        }
    }

    let r = &mut crate::test::rng(109);

    test_iter(r, 0..9);
    test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
    #[cfg(feature = "alloc")]
    test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
    test_iter(r, UnhintedIterator { iter: 0..9 });

    // Chunked lower bound, first without and then with an exact upper bound.
    for hint_total_size in [false, true] {
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size,
            },
        );
    }

    // Windowed lower bound, first without and then with an exact upper bound.
    for hint_total_size in [false, true] {
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size,
            },
        );
    }

    assert_eq!((0..0).choose(r), None);
    assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
}
412
/// Statistical smoke test for `choose_stable`: across iterators with
/// different `size_hint` behaviour, 1000 picks from nine elements must land
/// in every bucket strictly more than 72 and strictly fewer than 154 times.
/// An empty iterator must yield `None` from `choose_stable` itself.
///
/// All runs share one generator seeded with 109, so the order of the calls
/// below is part of the test.
#[test]
#[cfg_attr(miri, ignore)]
fn test_iterator_choose_stable() {
    let r = &mut crate::test::rng(109);

    /// Draw 1000 picks from clones of `iter` and check each bucket's count.
    fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
        let mut chosen = [0i32; 9];
        for _ in 0..1000 {
            let picked = iter.clone().choose_stable(r).unwrap();
            chosen[picked] += 1;
        }
        for count in chosen.iter() {
            // Loose band around the expected 1000/9 (about 111) per bucket.
            assert!(
                72 < *count && *count < 154,
                "count not close to 1000/9: {}",
                count
            );
        }
    }

    test_iter(r, 0..9);
    test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
    #[cfg(feature = "alloc")]
    test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
    test_iter(r, UnhintedIterator { iter: 0..9 });
    test_iter(
        r,
        ChunkHintedIterator {
            iter: 0..9,
            chunk_size: 4,
            chunk_remaining: 4,
            hint_total_size: false,
        },
    );
    test_iter(
        r,
        ChunkHintedIterator {
            iter: 0..9,
            chunk_size: 4,
            chunk_remaining: 4,
            hint_total_size: true,
        },
    );
    test_iter(
        r,
        WindowHintedIterator {
            iter: 0..9,
            window_size: 2,
            hint_total_size: false,
        },
    );
    test_iter(
        r,
        WindowHintedIterator {
            iter: 0..9,
            window_size: 2,
            hint_total_size: true,
        },
    );

    // Empty input must be handled by the method under test.
    assert_eq!((0..0).choose_stable(r), None);
    assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
}
478
/// Checks that `choose_stable` produces the same picks regardless of how the
/// iterator reports its `size_hint`: every variant below must reproduce the
/// histogram obtained from the plain range `0..9`.
#[test]
#[cfg_attr(miri, ignore)]
fn test_iterator_choose_stable_stability() {
    /// Histogram of 1000 picks, always starting from a fresh generator
    /// seeded with 109 so that runs are directly comparable.
    fn test_iter<Iter>(iter: Iter) -> [i32; 9]
    where
        Iter: Iterator<Item = usize> + Clone,
    {
        let r = &mut crate::test::rng(109);
        let mut chosen = [0i32; 9];
        for _ in 0..1000 {
            let picked = iter.clone().choose_stable(r).unwrap();
            chosen[picked] += 1;
        }
        chosen
    }

    let reference = test_iter(0..9);

    assert_eq!(
        test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
        reference
    );
    #[cfg(feature = "alloc")]
    assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
    assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);

    // Chunked lower bound, first without and then with an exact upper bound.
    for hint_total_size in [false, true] {
        let chunked = ChunkHintedIterator {
            iter: 0..9,
            chunk_size: 4,
            chunk_remaining: 4,
            hint_total_size,
        };
        assert_eq!(test_iter(chunked), reference);
    }

    // Windowed lower bound, first without and then with an exact upper bound.
    for hint_total_size in [false, true] {
        let windowed = WindowHintedIterator {
            iter: 0..9,
            window_size: 2,
            hint_total_size,
        };
        assert_eq!(test_iter(windowed), reference);
    }
}
536
/// Checks the sizes and contents returned by `choose_multiple`: a request
/// smaller than the input yields exactly that many in-range elements, and a
/// request larger than the input yields every element in its original order.
#[test]
#[cfg(feature = "alloc")]
fn test_sample_iter() {
    let (min_val, max_val) = (1, 100);

    let mut r = crate::test::rng(401);
    let vals: Vec<i32> = (min_val..max_val).collect();
    let small_sample = vals.iter().choose_multiple(&mut r, 5);
    let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);

    assert_eq!(small_sample.len(), 5);
    assert_eq!(large_sample.len(), vals.len());
    // Asking for more than is available returns the input unchanged.
    assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

    assert!(small_sample
        .iter()
        .all(|e| (min_val..=max_val).contains(*e)));
}
557
/// Pins the exact value `choose` returns for a generator seeded with 411, for
/// each kind of `size_hint` behaviour. The expected values differ between
/// hint styles; within one style, adding an exact upper bound does not change
/// the result.
#[test]
fn value_stability_choose() {
    /// Run `choose` on `iter` with a fresh generator seeded with 411.
    fn choose<I>(iter: I) -> Option<u32>
    where
        I: Iterator<Item = u32>,
    {
        let mut rng = crate::test::rng(411);
        iter.choose(&mut rng)
    }

    assert_eq!(choose([].iter().cloned()), None);
    assert_eq!(choose(0..100), Some(33));
    assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));

    for hint_total_size in [false, true] {
        let chunked = ChunkHintedIterator {
            iter: 0..100,
            chunk_size: 32,
            chunk_remaining: 32,
            hint_total_size,
        };
        assert_eq!(choose(chunked), Some(91));
    }

    for hint_total_size in [false, true] {
        let windowed = WindowHintedIterator {
            iter: 0..100,
            window_size: 32,
            hint_total_size,
        };
        assert_eq!(choose(windowed), Some(34));
    }
}
603
/// Pins the exact value `choose_stable` returns for a generator seeded with
/// 411. Every non-empty variant, whatever its `size_hint` behaviour, is
/// expected to produce the same element.
#[test]
fn value_stability_choose_stable() {
    /// Run `choose_stable` on `iter` with a fresh generator seeded with 411.
    fn choose<I>(iter: I) -> Option<u32>
    where
        I: Iterator<Item = u32>,
    {
        let mut rng = crate::test::rng(411);
        iter.choose_stable(&mut rng)
    }

    assert_eq!(choose([].iter().cloned()), None);
    assert_eq!(choose(0..100), Some(27));
    assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));

    for hint_total_size in [false, true] {
        let chunked = ChunkHintedIterator {
            iter: 0..100,
            chunk_size: 32,
            chunk_remaining: 32,
            hint_total_size,
        };
        assert_eq!(choose(chunked), Some(27));
    }

    for hint_total_size in [false, true] {
        let windowed = WindowHintedIterator {
            iter: 0..100,
            window_size: 32,
            hint_total_size,
        };
        assert_eq!(choose(windowed), Some(27));
    }
}
649
/// Pins the exact output of `choose_multiple_fill` (and, with the `alloc`
/// feature, `choose_multiple`) for a generator seeded with 412. Both methods
/// are expected to produce the same elements in the same order.
#[test]
fn value_stability_choose_multiple() {
    /// Sample from `iter` into an 8-slot buffer and compare against `v`,
    /// whose length is also the number of elements expected to be written.
    fn do_test<I>(iter: I, v: &[u32])
    where
        I: Clone + Iterator<Item = u32>,
    {
        let mut rng = crate::test::rng(412);
        let mut buf = [0u32; 8];
        let filled = iter.clone().choose_multiple_fill(&mut rng, &mut buf);
        assert_eq!(filled, v.len());
        assert_eq!(&buf[..v.len()], v);

        #[cfg(feature = "alloc")]
        {
            // Fresh generator with the same seed, so results are comparable.
            let mut rng = crate::test::rng(412);
            assert_eq!(iter.choose_multiple(&mut rng, v.len()), v);
        }
    }

    do_test(0..4, &[0, 1, 2, 3]);
    do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
    do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
}
672}