1use super::{Error, Weight};
10use crate::Rng;
11use crate::distr::Distribution;
12use crate::distr::uniform::{SampleBorrow, SampleUniform, UniformSampler};
13
14use alloc::vec::Vec;
16use core::fmt::{self, Debug};
17
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21#[derive(Debug, Clone, PartialEq)]
80#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
81pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
82 cumulative_weights: Vec<X>,
83 total_weight: X,
84 weight_distribution: X::Sampler,
85}
86
87impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
88 pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, Error>
100 where
101 I: IntoIterator,
102 I::Item: SampleBorrow<X>,
103 X: Weight,
104 {
105 let mut iter = weights.into_iter();
106 let mut total_weight: X = iter.next().ok_or(Error::InvalidInput)?.borrow().clone();
107
108 let zero = X::ZERO;
109 if !(total_weight >= zero) {
110 return Err(Error::InvalidWeight);
111 }
112
113 let mut weights = Vec::<X>::with_capacity(iter.size_hint().0);
114 for w in iter {
115 if !(w.borrow() >= &zero) {
118 return Err(Error::InvalidWeight);
119 }
120 weights.push(total_weight.clone());
121
122 if let Err(()) = total_weight.checked_add_assign(w.borrow()) {
123 return Err(Error::Overflow);
124 }
125 }
126
127 if total_weight == zero {
128 return Err(Error::InsufficientNonZero);
129 }
130 let distr = X::Sampler::new(zero, total_weight.clone()).unwrap();
131
132 Ok(WeightedIndex {
133 cumulative_weights: weights,
134 total_weight,
135 weight_distribution: distr,
136 })
137 }
138
139 pub fn update_weights(&mut self, new_weights: &[(usize, &X)]) -> Result<(), Error>
160 where
161 X: for<'a> core::ops::AddAssign<&'a X>
162 + for<'a> core::ops::SubAssign<&'a X>
163 + Clone
164 + Default,
165 {
166 if new_weights.is_empty() {
167 return Ok(());
168 }
169
170 let zero = <X as Default>::default();
171
172 let mut total_weight = self.total_weight.clone();
173
174 let mut prev_i = None;
177 for &(i, w) in new_weights {
178 if let Some(old_i) = prev_i {
179 if old_i >= i {
180 return Err(Error::InvalidInput);
181 }
182 }
183 if !(*w >= zero) {
184 return Err(Error::InvalidWeight);
185 }
186 if i > self.cumulative_weights.len() {
187 return Err(Error::InvalidInput);
188 }
189
190 let mut old_w = if i < self.cumulative_weights.len() {
191 self.cumulative_weights[i].clone()
192 } else {
193 self.total_weight.clone()
194 };
195 if i > 0 {
196 old_w -= &self.cumulative_weights[i - 1];
197 }
198
199 total_weight -= &old_w;
200 total_weight += w;
201 prev_i = Some(i);
202 }
203 if total_weight <= zero {
204 return Err(Error::InsufficientNonZero);
205 }
206
207 let mut iter = new_weights.iter();
210
211 let mut prev_weight = zero.clone();
212 let mut next_new_weight = iter.next();
213 let &(first_new_index, _) = next_new_weight.unwrap();
214 let mut cumulative_weight = if first_new_index > 0 {
215 self.cumulative_weights[first_new_index - 1].clone()
216 } else {
217 zero.clone()
218 };
219 for i in first_new_index..self.cumulative_weights.len() {
220 match next_new_weight {
221 Some(&(j, w)) if i == j => {
222 cumulative_weight += w;
223 next_new_weight = iter.next();
224 }
225 _ => {
226 let mut tmp = self.cumulative_weights[i].clone();
227 tmp -= &prev_weight; cumulative_weight += &tmp;
229 }
230 }
231 prev_weight = cumulative_weight.clone();
232 core::mem::swap(&mut prev_weight, &mut self.cumulative_weights[i]);
233 }
234
235 self.total_weight = total_weight;
236 self.weight_distribution = X::Sampler::new(zero, self.total_weight.clone()).unwrap();
237
238 Ok(())
239 }
240}
241
242pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> {
245 weighted_index: &'a WeightedIndex<X>,
246 index: usize,
247}
248
249impl<X> Debug for WeightedIndexIter<'_, X>
250where
251 X: SampleUniform + PartialOrd + Debug,
252 X::Sampler: Debug,
253{
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 f.debug_struct("WeightedIndexIter")
256 .field("weighted_index", &self.weighted_index)
257 .field("index", &self.index)
258 .finish()
259 }
260}
261
262impl<X> Clone for WeightedIndexIter<'_, X>
263where
264 X: SampleUniform + PartialOrd,
265{
266 fn clone(&self) -> Self {
267 WeightedIndexIter {
268 weighted_index: self.weighted_index,
269 index: self.index,
270 }
271 }
272}
273
274impl<X> Iterator for WeightedIndexIter<'_, X>
275where
276 X: for<'b> core::ops::SubAssign<&'b X> + SampleUniform + PartialOrd + Clone,
277{
278 type Item = X;
279
280 fn next(&mut self) -> Option<Self::Item> {
281 match self.weighted_index.weight(self.index) {
282 None => None,
283 Some(weight) => {
284 self.index += 1;
285 Some(weight)
286 }
287 }
288 }
289}
290
291impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
292 pub fn weight(&self, index: usize) -> Option<X>
309 where
310 X: for<'a> core::ops::SubAssign<&'a X>,
311 {
312 use core::cmp::Ordering::*;
313
314 let mut weight = match index.cmp(&self.cumulative_weights.len()) {
315 Less => self.cumulative_weights[index].clone(),
316 Equal => self.total_weight.clone(),
317 Greater => return None,
318 };
319
320 if index > 0 {
321 weight -= &self.cumulative_weights[index - 1];
322 }
323 Some(weight)
324 }
325
326 pub fn weights(&self) -> WeightedIndexIter<'_, X>
343 where
344 X: for<'a> core::ops::SubAssign<&'a X>,
345 {
346 WeightedIndexIter {
347 weighted_index: self,
348 index: 0,
349 }
350 }
351
352 pub fn total_weight(&self) -> X {
354 self.total_weight.clone()
355 }
356}
357
358impl<X> Distribution<usize> for WeightedIndex<X>
359where
360 X: SampleUniform + PartialOrd,
361{
362 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
363 let chosen_weight = self.weight_distribution.sample(rng);
364 self.cumulative_weights
366 .partition_point(|w| w <= &chosen_weight)
367 }
368}
369
#[cfg(test)]
mod test {
    use super::*;
    use crate::RngExt;

    // Round-trips a `WeightedIndex<i32>` through `postcard`. Checks that the
    // prefix sums and the total weight are preserved.
    #[cfg(feature = "serde")]
    #[test]
    fn test_weightedindex_serde() {
        let weighted_index = WeightedIndex::new([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

        let ser_weighted_index = postcard::to_allocvec(&weighted_index).unwrap();
        let de_weighted_index: WeightedIndex<i32> =
            postcard::from_bytes(&ser_weighted_index).unwrap();

        assert_eq!(
            de_weighted_index.cumulative_weights,
            weighted_index.cumulative_weights
        );
        assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
    }

    // A NaN weight must be rejected with `InvalidWeight` in every position:
    // first, only, later, and when supplied via `update_weights`.
    #[test]
    fn test_accepting_nan() {
        assert_eq!(
            WeightedIndex::new([f32::NAN, 0.5]).unwrap_err(),
            Error::InvalidWeight,
        );
        assert_eq!(
            WeightedIndex::new([f32::NAN]).unwrap_err(),
            Error::InvalidWeight,
        );
        assert_eq!(
            WeightedIndex::new([0.5, f32::NAN]).unwrap_err(),
            Error::InvalidWeight,
        );

        assert_eq!(
            WeightedIndex::new([0.5, 7.0])
                .unwrap()
                .update_weights(&[(0, &f32::NAN)])
                .unwrap_err(),
            Error::InvalidWeight,
        )
    }

    // Statistical check that sample frequencies track the weights, for each way
    // of passing weights to `new`, plus the constructor's error cases.
    #[test]
    // Skipped under Miri, presumably because of the number of samples drawn.
    #[cfg_attr(miri, ignore)]
    fn test_weightedindex() {
        let mut r = crate::test::rng(700);
        const N_REPS: u32 = 5000;
        let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
        let total_weight = weights.iter().sum::<u32>() as f32;

        // Each observed count must be within 25% (relative) of its expected
        // count. A zero-weight index has an expected count of 0, so any hit on
        // it produces an infinite relative error and fails.
        let verify = |result: [i32; 14]| {
            for (i, count) in result.iter().enumerate() {
                let exp = (weights[i] * N_REPS) as f32 / total_weight;
                let mut err = (*count as f32 - exp).abs();
                if err != 0.0 {
                    err /= exp;
                }
                assert!(err <= 0.25);
            }
        };

        // Weights passed as an owned `Vec<u32>`.
        let mut chosen = [0i32; 14];
        let distr = WeightedIndex::new(weights.to_vec()).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // Weights passed as a slice (items are `&u32`).
        chosen = [0i32; 14];
        let distr = WeightedIndex::new(&weights[..]).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // Weights passed as an iterator of references.
        chosen = [0i32; 14];
        let distr = WeightedIndex::new(weights.iter()).unwrap();
        for _ in 0..N_REPS {
            chosen[distr.sample(&mut r)] += 1;
        }
        verify(chosen);

        // With exactly one non-zero weight, that index is always chosen.
        for _ in 0..5 {
            assert_eq!(WeightedIndex::new([0, 1]).unwrap().sample(&mut r), 1);
            assert_eq!(WeightedIndex::new([1, 0]).unwrap().sample(&mut r), 0);
            assert_eq!(
                WeightedIndex::new([0, 0, 0, 0, 10, 0])
                    .unwrap()
                    .sample(&mut r),
                4
            );
        }

        // Constructor errors: empty input, all-zero, and negative weights in
        // various positions.
        assert_eq!(
            WeightedIndex::new(&[10][0..0]).unwrap_err(),
            Error::InvalidInput
        );
        assert_eq!(
            WeightedIndex::new([0]).unwrap_err(),
            Error::InsufficientNonZero
        );
        assert_eq!(
            WeightedIndex::new([10, 20, -1, 30]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(
            WeightedIndex::new([-10, 20, 1, 30]).unwrap_err(),
            Error::InvalidWeight
        );
        assert_eq!(WeightedIndex::new([-10]).unwrap_err(), Error::InvalidWeight);
    }

    // Applies an update and checks that the result equals a freshly built
    // distribution over the expected weights. Each entry is
    // `(initial weights, updates, expected weights)`.
    #[test]
    fn test_update_weights() {
        let data = [
            (
                &[10u32, 2, 3, 4][..],
                // updates
                &[(1, &100), (2, &4)][..],
                // expected weights
                &[10, 100, 4, 4][..],
            ),
            (
                &[1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
                // updates; index 13 is the last weight, which is held only in
                // `total_weight`
                &[(2, &1), (5, &1), (13, &100)][..],
                // expected weights
                &[1u32, 2, 1, 0, 5, 1, 7, 1, 2, 3, 4, 5, 6, 100][..],
            ),
        ];

        for (weights, update, expected_weights) in data.iter() {
            let total_weight = weights.iter().sum::<u32>();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, total_weight);

            distr.update_weights(update).unwrap();
            let expected_total_weight = expected_weights.iter().sum::<u32>();
            let expected_distr = WeightedIndex::new(expected_weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, expected_total_weight);
            assert_eq!(distr.total_weight, expected_distr.total_weight);
            assert_eq!(distr.cumulative_weights, expected_distr.cumulative_weights);
        }
    }

    // Each entry is `(initial weights, updates, expected error)`.
    #[test]
    fn test_update_weights_errors() {
        let data = [
            (
                // Zeroing the only non-zero weight.
                &[1i32, 0, 0][..],
                &[(0, &0)][..],
                Error::InsufficientNonZero,
            ),
            (
                // Negative new weight.
                &[10, 10, 10, 10][..],
                &[(1, &-11)][..],
                Error::InvalidWeight,
            ),
            (
                // Indices not strictly increasing.
                &[1, 2, 3, 4, 5][..],
                &[(1, &5), (0, &5)][..],
                Error::InvalidInput,
            ),
            (
                // Index past the last weight.
                &[1][..],
                &[(1, &1)][..],
                Error::InvalidInput,
            ),
        ];

        for (weights, update, err) in data.iter() {
            let total_weight = weights.iter().sum::<i32>();
            let mut distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.total_weight, total_weight);
            match distr.update_weights(update) {
                Ok(_) => panic!("Expected update_weights to fail, but it succeeded"),
                Err(e) => assert_eq!(e, *err),
            }
        }
    }

    // `weight(i)` returns each original weight and `None` one past the end.
    // This includes a single `u32::MAX` weight, held only in `total_weight`.
    #[test]
    fn test_weight_at() {
        let data = [
            &[1][..],
            &[10, 2, 3, 4][..],
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
            &[u32::MAX][..],
        ];

        for weights in data.iter() {
            let distr = WeightedIndex::new(weights.to_vec()).unwrap();
            for (i, weight) in weights.iter().enumerate() {
                assert_eq!(distr.weight(i), Some(*weight));
            }
            assert_eq!(distr.weight(weights.len()), None);
        }
    }

    // `weights()` reproduces the original weight list in order.
    #[test]
    fn test_weights() {
        let data = [
            &[1][..],
            &[10, 2, 3, 4][..],
            &[1, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7][..],
            &[u32::MAX][..],
        ];

        for weights in data.iter() {
            let distr = WeightedIndex::new(weights.to_vec()).unwrap();
            assert_eq!(distr.weights().collect::<Vec<_>>(), weights.to_vec());
        }
    }

    // Pins the exact sample sequence for a fixed seed (701), so that any change
    // to the sampling algorithm's output is detected.
    #[test]
    fn value_stability() {
        fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
            weights: I,
            buf: &mut [usize],
            expected: &[usize],
        ) where
            I: IntoIterator,
            I::Item: SampleBorrow<X>,
        {
            assert_eq!(buf.len(), expected.len());
            let distr = WeightedIndex::new(weights).unwrap();
            let mut rng = crate::test::rng(701);
            for r in buf.iter_mut() {
                *r = rng.sample(&distr);
            }
            assert_eq!(buf, expected);
        }

        let mut buf = [0; 10];
        test_samples(
            [1i32, 1, 1, 1, 1, 1, 1, 1, 1],
            &mut buf,
            &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5],
        );
        test_samples(
            [0.7f32, 0.1, 0.1, 0.1],
            &mut buf,
            &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0],
        );
        test_samples(
            [1.0f64, 0.999, 0.998, 0.997],
            &mut buf,
            &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1],
        );
    }

    // The derived `PartialEq` makes constructor results comparable.
    #[test]
    fn weighted_index_distributions_can_be_compared() {
        assert_eq!(WeightedIndex::new([1, 2]), WeightedIndex::new([1, 2]));
    }

    // Summing integer weights past `usize::MAX` is reported as `Overflow`.
    #[test]
    fn overflow() {
        assert_eq!(WeightedIndex::new([2, usize::MAX]), Err(Error::Overflow));
    }
}