1use super::coin_flipper::CoinFlipper;
12#[allow(unused)]
13use super::IndexedRandom;
14use crate::Rng;
15#[cfg(feature = "alloc")]
16use alloc::vec::Vec;
17
18pub trait IteratorRandom: Iterator + Sized {
35 fn choose<R>(mut self, rng: &mut R) -> Option<Self::Item>
67 where
68 R: Rng + ?Sized,
69 {
70 let (mut lower, mut upper) = self.size_hint();
71 let mut result = None;
72
73 if upper == Some(lower) {
77 return match lower {
78 0 => None,
79 1 => self.next(),
80 _ => self.nth(rng.random_range(..lower)),
81 };
82 }
83
84 let mut coin_flipper = CoinFlipper::new(rng);
85 let mut consumed = 0;
86
87 loop {
89 if lower > 1 {
90 let ix = coin_flipper.rng.random_range(..lower + consumed);
91 let skip = if ix < lower {
92 result = self.nth(ix);
93 lower - (ix + 1)
94 } else {
95 lower
96 };
97 if upper == Some(lower) {
98 return result;
99 }
100 consumed += lower;
101 if skip > 0 {
102 self.nth(skip - 1);
103 }
104 } else {
105 let elem = self.next();
106 if elem.is_none() {
107 return result;
108 }
109 consumed += 1;
110 if coin_flipper.random_ratio_one_over(consumed) {
111 result = elem;
112 }
113 }
114
115 let hint = self.size_hint();
116 lower = hint.0;
117 upper = hint.1;
118 }
119 }
120
121 #[allow(clippy::double_ended_iterator_last)]
141 fn choose_stable<R>(mut self, rng: &mut R) -> Option<Self::Item>
142 where
143 R: Rng + ?Sized,
144 {
145 let mut consumed = 0;
146 let mut result = None;
147 let mut coin_flipper = CoinFlipper::new(rng);
148
149 loop {
150 let mut next = 0;
155
156 let (lower, _) = self.size_hint();
157 if lower >= 2 {
158 let highest_selected = (0..lower)
159 .filter(|ix| coin_flipper.random_ratio_one_over(consumed + ix + 1))
160 .last();
161
162 consumed += lower;
163 next = lower;
164
165 if let Some(ix) = highest_selected {
166 result = self.nth(ix);
167 next -= ix + 1;
168 debug_assert!(result.is_some(), "iterator shorter than size_hint().0");
169 }
170 }
171
172 let elem = self.nth(next);
173 if elem.is_none() {
174 return result;
175 }
176
177 if coin_flipper.random_ratio_one_over(consumed + 1) {
178 result = elem;
179 }
180 consumed += 1;
181 }
182 }
183
184 fn choose_multiple_fill<R>(mut self, rng: &mut R, buf: &mut [Self::Item]) -> usize
200 where
201 R: Rng + ?Sized,
202 {
203 let amount = buf.len();
204 let mut len = 0;
205 while len < amount {
206 if let Some(elem) = self.next() {
207 buf[len] = elem;
208 len += 1;
209 } else {
210 return len;
212 }
213 }
214
215 for (i, elem) in self.enumerate() {
217 let k = rng.random_range(..i + 1 + amount);
218 if let Some(slot) = buf.get_mut(k) {
219 *slot = elem;
220 }
221 }
222 len
223 }
224
225 #[cfg(feature = "alloc")]
240 fn choose_multiple<R>(mut self, rng: &mut R, amount: usize) -> Vec<Self::Item>
241 where
242 R: Rng + ?Sized,
243 {
244 let mut reservoir = Vec::with_capacity(amount);
245 reservoir.extend(self.by_ref().take(amount));
246
247 if reservoir.len() == amount {
252 for (i, elem) in self.enumerate() {
253 let k = rng.random_range(..i + 1 + amount);
254 if let Some(slot) = reservoir.get_mut(k) {
255 *slot = elem;
256 }
257 }
258 } else {
259 reservoir.shrink_to_fit();
262 }
263 reservoir
264 }
265}
266
267impl<I> IteratorRandom for I where I: Iterator + Sized {}
268
#[cfg(test)]
mod test {
    use super::*;
    #[cfg(all(feature = "alloc", not(feature = "std")))]
    use alloc::vec::Vec;

    /// Forwards only `next`, so `size_hint` is the default `(0, None)`.
    #[derive(Clone)]
    struct UnhintedIterator<I: Iterator + Clone> {
        iter: I,
    }
    impl<I: Iterator + Clone> Iterator for UnhintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }
    }

    /// Hints, as the lower bound, only the rest of the current chunk of
    /// `chunk_size` elements. The upper bound is the exact remaining length
    /// when `hint_total_size` is set, and `None` otherwise.
    #[derive(Clone)]
    struct ChunkHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        chunk_remaining: usize,
        chunk_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for ChunkHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            if self.chunk_remaining == 0 {
                self.chunk_remaining = core::cmp::min(self.chunk_size, self.iter.len());
            }
            self.chunk_remaining = self.chunk_remaining.saturating_sub(1);

            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                self.chunk_remaining,
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    /// Hints `min(remaining, window_size)` as the lower bound. The upper bound
    /// is the exact remaining length when `hint_total_size` is set, and `None`
    /// otherwise.
    #[derive(Clone)]
    struct WindowHintedIterator<I: ExactSizeIterator + Iterator + Clone> {
        iter: I,
        window_size: usize,
        hint_total_size: bool,
    }
    impl<I: ExactSizeIterator + Iterator + Clone> Iterator for WindowHintedIterator<I> {
        type Item = I::Item;

        fn next(&mut self) -> Option<Self::Item> {
            self.iter.next()
        }

        fn size_hint(&self) -> (usize, Option<usize>) {
            (
                core::cmp::min(self.iter.len(), self.window_size),
                if self.hint_total_size {
                    Some(self.iter.len())
                } else {
                    None
                },
            )
        }
    }

    #[test]
    #[cfg_attr(miri, ignore)] // many RNG iterations; skipped under Miri
    fn test_iterator_choose() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Each count should be near 1000/9 ≈ 111; the bounds are a
                // wide acceptance window around that expectation.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        assert_eq!((0..0).choose(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // many RNG iterations; skipped under Miri
    fn test_iterator_choose_stable() {
        let r = &mut crate::test::rng(109);
        fn test_iter<R: Rng + ?Sized, Iter: Iterator<Item = usize> + Clone>(r: &mut R, iter: Iter) {
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            for count in chosen.iter() {
                // Each count should be near 1000/9 ≈ 111; the bounds are a
                // wide acceptance window around that expectation.
                assert!(
                    72 < *count && *count < 154,
                    "count not close to 1000/9: {}",
                    count
                );
            }
        }

        test_iter(r, 0..9);
        test_iter(r, [0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned());
        #[cfg(feature = "alloc")]
        test_iter(r, (0..9).collect::<Vec<_>>().into_iter());
        test_iter(r, UnhintedIterator { iter: 0..9 });
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            },
        );
        test_iter(
            r,
            WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            },
        );

        // Empty inputs must be exercised through `choose_stable` itself.
        assert_eq!((0..0).choose_stable(r), None);
        assert_eq!(UnhintedIterator { iter: 0..0 }.choose_stable(r), None);
    }

    #[test]
    #[cfg_attr(miri, ignore)] // many RNG iterations; skipped under Miri
    fn test_iterator_choose_stable_stability() {
        // Fresh RNG per call, so every iterator variant sees the same stream.
        fn test_iter(iter: impl Iterator<Item = usize> + Clone) -> [i32; 9] {
            let r = &mut crate::test::rng(109);
            let mut chosen = [0i32; 9];
            for _ in 0..1000 {
                let picked = iter.clone().choose_stable(r).unwrap();
                chosen[picked] += 1;
            }
            chosen
        }

        let reference = test_iter(0..9);
        assert_eq!(
            test_iter([0, 1, 2, 3, 4, 5, 6, 7, 8].iter().cloned()),
            reference
        );

        #[cfg(feature = "alloc")]
        assert_eq!(test_iter((0..9).collect::<Vec<_>>().into_iter()), reference);
        assert_eq!(test_iter(UnhintedIterator { iter: 0..9 }), reference);
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(ChunkHintedIterator {
                iter: 0..9,
                chunk_size: 4,
                chunk_remaining: 4,
                hint_total_size: true,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: false,
            }),
            reference
        );
        assert_eq!(
            test_iter(WindowHintedIterator {
                iter: 0..9,
                window_size: 2,
                hint_total_size: true,
            }),
            reference
        );
    }

    #[test]
    #[cfg(feature = "alloc")]
    fn test_sample_iter() {
        let min_val = 1;
        let max_val = 100;

        let mut r = crate::test::rng(401);
        let vals = (min_val..max_val).collect::<Vec<i32>>();
        let small_sample = vals.iter().choose_multiple(&mut r, 5);
        let large_sample = vals.iter().choose_multiple(&mut r, vals.len() + 5);

        assert_eq!(small_sample.len(), 5);
        assert_eq!(large_sample.len(), vals.len());
        // Asking for more than available returns everything, in order.
        assert_eq!(large_sample, vals.iter().collect::<Vec<_>>());

        // `vals` was built from the half-open range, so check against it exactly.
        assert!(small_sample
            .iter()
            .all(|e| (min_val..max_val).contains(*e)));
    }

    #[test]
    fn value_stability_choose() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(33));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(91)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(91)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(34)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(34)
        );
    }

    #[test]
    fn value_stability_choose_stable() {
        fn choose<I: Iterator<Item = u32>>(iter: I) -> Option<u32> {
            let mut rng = crate::test::rng(411);
            iter.choose_stable(&mut rng)
        }

        assert_eq!(choose([].iter().cloned()), None);
        assert_eq!(choose(0..100), Some(27));
        assert_eq!(choose(UnhintedIterator { iter: 0..100 }), Some(27));
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(ChunkHintedIterator {
                iter: 0..100,
                chunk_size: 32,
                chunk_remaining: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: false,
            }),
            Some(27)
        );
        assert_eq!(
            choose(WindowHintedIterator {
                iter: 0..100,
                window_size: 32,
                hint_total_size: true,
            }),
            Some(27)
        );
    }

    #[test]
    fn value_stability_choose_multiple() {
        // Both sampling entry points must agree for the same seed.
        fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
            let mut rng = crate::test::rng(412);
            let mut buf = [0u32; 8];
            assert_eq!(
                iter.clone().choose_multiple_fill(&mut rng, &mut buf),
                v.len()
            );
            assert_eq!(&buf[0..v.len()], v);

            #[cfg(feature = "alloc")]
            {
                let mut rng = crate::test::rng(412);
                assert_eq!(iter.choose_multiple(&mut rng, v.len()), v);
            }
        }

        do_test(0..4, &[0, 1, 2, 3]);
        do_test(0..8, &[0, 1, 2, 3, 4, 5, 6, 7]);
        do_test(0..100, &[77, 95, 38, 23, 25, 8, 58, 40]);
    }
}