1#![allow(clippy::cast_possible_truncation, reason = "needs triage")]
2#![allow(clippy::undocumented_unsafe_blocks, reason = "TODO")]
3
4use core::fmt;
5
6use rand_core::{
7 Infallible, SeedableRng, TryCryptoRng, TryRng,
8 block::{BlockRng, Generator},
9};
10
11#[cfg(feature = "zeroize")]
12use zeroize::{Zeroize, ZeroizeOnDrop};
13
14use crate::{
15 ChaChaCore, R8, R12, R20, Rounds, backends,
16 variants::{Legacy, Variant},
17};
18
19use cfg_if::cfg_if;
20
/// Seed accepted by the ChaCha RNGs: 32 bytes, stored as state words 4..12
/// (see `get_seed`).
pub type Seed = [u8; 32];

/// Fixed-size output of `serialize_state`: the 32-byte seed, the 8-byte
/// little-endian stream id, and the low 9 bytes of the little-endian word
/// position (49 bytes in total).
pub type SerializedRngState = [u8; 49];

/// Number of 32-bit words in a single ChaCha block.
pub(crate) const BLOCK_WORDS: u8 = 16;

/// Number of ChaCha blocks produced by one call to `Generator::generate`.
const BUF_BLOCKS: u8 = 4;

/// Length, in 32-bit words, of one refill buffer (`BLOCK_WORDS * BUF_BLOCKS`).
const BUFFER_SIZE: usize = BLOCK_WORDS as usize * BUF_BLOCKS as usize;
34
35impl<R: Rounds, V: Variant> SeedableRng for ChaChaCore<R, V> {
36 type Seed = Seed;
37
38 #[inline]
39 fn from_seed(seed: Self::Seed) -> Self {
40 ChaChaCore::new_internal(&seed, &[0u8; 8])
41 }
42}
43
/// Block generator used by `BlockRng`: each call to `generate` refills one
/// buffer of `BUF_BLOCKS` ChaCha blocks.
impl<R: Rounds, V: Variant> Generator for ChaChaCore<R, V> {
    /// One refill buffer of `BUFFER_SIZE` (= `BLOCK_WORDS * BUF_BLOCKS`) 32-bit words.
    type Output = [u32; BUFFER_SIZE];

    /// Fills `buffer` with the next `BUF_BLOCKS` ChaCha blocks of output.
    ///
    /// The backend is selected at compile time from the `chacha20_backend` cfg
    /// and the target architecture. On x86/x86_64 without an explicit override
    /// it is selected at run time from the feature tokens stored in `self.tokens`.
    ///
    /// NOTE(review): presumably every backend also advances the 64-bit block
    /// counter in `self.state` by `BUF_BLOCKS`; the RNG position getters
    /// (`get_word_pos`, `get_block_pos`) rely on that — confirm in `backends`.
    fn generate(&mut self, buffer: &mut [u32; BUFFER_SIZE]) {
        cfg_if! {
            if #[cfg(chacha20_backend = "soft")] {
                // Portable backend explicitly forced via `chacha20_backend = "soft"`.
                backends::soft::Backend(self).gen_ks_blocks(buffer);
            } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
                cfg_if! {
                    if #[cfg(any(chacha20_backend = "avx2", chacha20_backend = "avx512"))] {
                        // Both the "avx2" and "avx512" overrides use the AVX2 RNG
                        // routine; this block calls no AVX-512-specific RNG code.
                        // SAFETY: NOTE(review) — soundness rests on the
                        // `chacha20_backend` override guaranteeing that the target
                        // CPU supports AVX2; confirm that requirement is documented
                        // for users who set the cfg.
                        unsafe {
                            backends::avx2::rng_inner::<R, V>(self, buffer);
                        }
                    } else if #[cfg(chacha20_backend = "sse2")] {
                        // SAFETY: NOTE(review) — relies on the "sse2" override
                        // guaranteeing SSE2 support on the target CPU; confirm.
                        unsafe {
                            backends::sse2::rng_inner::<R, V>(self, buffer);
                        }
                    } else {
                        // No override: choose a backend at run time. When the
                        // `chacha20_avx512` cfg is on, the AVX-512 token is ignored
                        // here; AVX2 is the widest backend this block uses.
                        #[cfg(chacha20_avx512)]
                        let (_avx512_token, avx2_token, sse2_token) = self.tokens;
                        #[cfg(not(chacha20_avx512))]
                        let (avx2_token, sse2_token) = self.tokens;

                        if avx2_token.get() {
                            // SAFETY: NOTE(review) — presumably `avx2_token.get()`
                            // reports run-time-detected AVX2 support; confirm.
                            unsafe {
                                backends::avx2::rng_inner::<R, V>(self, buffer);
                            }
                        } else if sse2_token.get() {
                            // SAFETY: NOTE(review) — presumably `sse2_token.get()`
                            // reports run-time-detected SSE2 support; confirm.
                            unsafe {
                                backends::sse2::rng_inner::<R, V>(self, buffer);
                            }
                        } else {
                            // Neither SIMD feature available: portable fallback.
                            backends::soft::Backend(self).gen_ks_blocks(buffer);
                        }
                    }
                }
            } else if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] {
                // SAFETY: this branch is only compiled when `target_feature = "neon"`
                // is enabled at compile time. NOTE(review): confirm that NEON
                // availability is `rng_inner`'s only precondition.
                unsafe {
                    backends::neon::rng_inner::<R, V>(self, buffer);
                }
            } else {
                // Any other target: portable backend.
                backends::soft::Backend(self).gen_ks_blocks(buffer);
            }
        }
    }

    /// Zeroizes a spent output buffer when the `zeroize` feature is enabled.
    ///
    /// This stops generated words from lingering in memory. NOTE(review):
    /// presumably invoked by `BlockRng` when it is dropped; confirm with rand_core.
    #[cfg(feature = "zeroize")]
    fn drop(&mut self, output: &mut Self::Output) {
        output.zeroize();
    }
}
101
/// Defines a public ChaCha RNG type `$Rng` driven by `ChaChaCore<$rounds, Legacy>`.
///
/// The generated type implements `SeedableRng`, `TryRng` (infallible),
/// `TryCryptoRng`, `PartialEq`/`Eq`, an opaque `Debug`, and inherent methods to
/// inspect, seek, and (de)serialize the generator position.
macro_rules! impl_chacha_rng {
    ($Rng:ident, $rounds:ident) => {
        /// A cryptographically secure random number generator built on the ChaCha
        /// stream cipher, using the legacy state layout: a 64-bit block counter in
        /// state words 12..14 and a 64-bit stream id in state words 14..16.
        ///
        /// Output is produced in buffers of several ChaCha blocks. The position in
        /// the output stream can be read and set in 32-bit words (`get_word_pos`,
        /// `set_word_pos`) or in whole blocks (`get_block_pos`, `set_block_pos`).
        ///
        /// # Example
        ///
        /// ```
        #[doc = concat!("use chacha20::", stringify!($Rng), ";")]
        /// use rand_core::{SeedableRng, TryRng};
        ///
        /// let seed = [42u8; 32];
        #[doc = concat!("let mut rng = ", stringify!($Rng), "::from_seed(seed);")]
        /// let _word = rng.try_next_u32().unwrap();
        /// ```
        pub struct $Rng {
            core: BlockRng<ChaChaCore<$rounds, Legacy>>,
        }

        impl SeedableRng for $Rng {
            type Seed = Seed;

            /// Creates a generator from `seed` with an all-zero stream id.
            #[inline]
            fn from_seed(seed: Self::Seed) -> Self {
                let core = ChaChaCore::new_internal(&seed, &[0u8; 8]);
                Self {
                    core: BlockRng::new(core),
                }
            }
        }

        impl TryRng for $Rng {
            type Error = Infallible;

            /// Returns the next 32-bit output word; never fails.
            #[inline]
            fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
                Ok(self.core.next_word())
            }

            /// Returns the next 64 bits of output (via `next_u64_from_u32`); never fails.
            #[inline]
            fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
                Ok(self.core.next_u64_from_u32())
            }

            /// Fills `dest` with output bytes; never fails.
            #[inline]
            fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
                self.core.fill_bytes(dest);
                Ok(())
            }
        }

        impl TryCryptoRng for $Rng {}

        // NOTE(review): this marker relies on `BlockRng<ChaChaCore<..>>` wiping its
        // state on drop. The `Generator::drop` impl for `ChaChaCore` zeroizes spent
        // output buffers; confirm the core's key/state words are zeroized as well.
        #[cfg(feature = "zeroize")]
        impl ZeroizeOnDrop for $Rng {}

        /// Two generators compare equal when they have the same seed, stream id,
        /// and word position.
        impl PartialEq<$Rng> for $Rng {
            fn eq(&self, rhs: &$Rng) -> bool {
                (self.get_seed() == rhs.get_seed())
                    && (self.get_stream() == rhs.get_stream())
                    && (self.get_word_pos() == rhs.get_word_pos())
            }
        }

        impl Eq for $Rng {}

        /// Opaque formatting: prints only `<TypeName> { ... }`, so no seed or state
        /// appears in debug output.
        impl fmt::Debug for $Rng {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(f, concat!(stringify!($Rng), " {{ ... }}"))
            }
        }

        impl $Rng {
            /// Returns the current position in the output stream, counted in 32-bit
            /// words from the start of block 0.
            ///
            /// The value is reduced modulo 2^68 (a 64-bit block counter times
            /// `BLOCK_WORDS` words per block), so it always fits in 9 bytes.
            #[inline]
            #[must_use]
            pub fn get_word_pos(&self) -> u128 {
                let mut block_counter = (u64::from(self.core.core.state[13]) << 32)
                    | u64::from(self.core.core.state[12]);
                // While a buffer is partially consumed, the stored counter already
                // points past the `BUF_BLOCKS` buffered blocks; step back to the
                // buffer's first block. The counter may wrap, hence `wrapping_sub`.
                if self.core.word_offset() != 0 {
                    block_counter = block_counter.wrapping_sub(u64::from(BUF_BLOCKS));
                }
                let word_pos = u128::from(block_counter) * u128::from(BLOCK_WORDS)
                    + self.core.word_offset() as u128;
                word_pos & ((1 << 68) - 1)
            }

            /// Seeks to `word_offset`, counted in 32-bit words from the start of
            /// block 0.
            ///
            /// Buffered output is reset (via `reset_and_skip`) and the first
            /// `word_offset % BLOCK_WORDS` words of the target block are skipped.
            /// Only the low 64 bits of the block index `word_offset / BLOCK_WORDS`
            /// are stored, so positions wrap at 2^68.
            #[inline]
            pub fn set_word_pos(&mut self, word_offset: u128) {
                let index = (word_offset % u128::from(BLOCK_WORDS)) as usize;
                let counter = word_offset / u128::from(BLOCK_WORDS);
                // Low and high halves of the 64-bit block counter; any higher bits
                // of `counter` are intentionally dropped by the `as u32` casts.
                self.core.core.state[12] = counter as u32;
                self.core.core.state[13] = (counter >> 32) as u32;
                self.core.reset_and_skip(index);
            }

            /// Seeks to the first word of block `block_pos`, resetting buffered output.
            #[inline]
            #[allow(unused)]
            pub fn set_block_pos(&mut self, block_pos: u64) {
                self.core.reset_and_skip(0);
                self.core.core.set_block_pos(block_pos);
            }

            /// Returns the index of the block that contains the next output word.
            ///
            /// The 64-bit block counter may wrap around, so this uses wrapping
            /// arithmetic (as `get_word_pos` does) rather than panicking on
            /// underflow in debug builds.
            #[inline]
            #[allow(unused)]
            #[must_use]
            pub fn get_block_pos(&self) -> u64 {
                let counter = self.core.core.get_block_pos();
                let offset = self.core.word_offset();
                if offset == 0 {
                    counter
                } else {
                    // The buffer was generated starting `BUF_BLOCKS` blocks before
                    // `counter`. `offset` is the index of the next word within it.
                    counter
                        .wrapping_sub(u64::from(BUF_BLOCKS))
                        .wrapping_add(offset as u64 / u64::from(BLOCK_WORDS))
                }
            }

            /// Switches to stream `stream` and rewinds to block 0.
            #[inline]
            pub fn set_stream(&mut self, stream: u64) {
                self.core.core.state[14] = stream as u32;
                self.core.core.state[15] = (stream >> 32) as u32;
                self.set_block_pos(0);
            }

            /// Returns the 64-bit stream id held in state words 14 (low) and 15 (high).
            #[inline]
            #[must_use]
            pub fn get_stream(&self) -> u64 {
                let mut result = [0u8; 8];
                result[..4].copy_from_slice(&self.core.core.state[14].to_le_bytes());
                result[4..].copy_from_slice(&self.core.core.state[15].to_le_bytes());
                u64::from_le_bytes(result)
            }

            /// Returns the 32-byte seed: state words 4..12, each as little-endian bytes.
            #[inline]
            #[must_use]
            pub fn get_seed(&self) -> [u8; 32] {
                let seed = &self.core.core.state[4..12];
                let mut result = [0u8; 32];
                for (src, dst) in seed.iter().zip(result.chunks_exact_mut(4)) {
                    dst.copy_from_slice(&src.to_le_bytes());
                }
                result
            }

            /// Serializes the generator position into a `SerializedRngState`.
            ///
            /// Layout: seed (32 bytes) | stream id, little-endian (8 bytes) |
            /// word position, little-endian, low 9 bytes. The word position is
            /// below 2^68, so the dropped upper 7 bytes are always zero.
            #[inline]
            pub fn serialize_state(&self) -> SerializedRngState {
                let seed = self.get_seed();
                let stream = self.get_stream().to_le_bytes();
                let word_pos = self.get_word_pos().to_le_bytes();

                let mut res = [0u8; 49];
                let (seed_dst, res_rem) = res.split_at_mut(32);
                let (stream_dst, word_pos_dst) = res_rem.split_at_mut(8);

                seed_dst.copy_from_slice(&seed);
                stream_dst.copy_from_slice(&stream);
                word_pos_dst.copy_from_slice(&word_pos[..9]);

                debug_assert_eq!(&word_pos[9..], &[0u8; 7]);

                res
            }

            /// Reconstructs a generator from the output of `serialize_state`.
            ///
            /// Any 49-byte input is accepted. Word positions of 2^68 or more wrap,
            /// as described on `set_word_pos`.
            #[inline]
            pub fn deserialize_state(state: &SerializedRngState) -> Self {
                let (seed, state_rem) = state.split_at(32);
                let (stream, word_pos_raw) = state_rem.split_at(8);

                let seed: &[u8; 32] = seed.try_into().expect("seed.len() is equal to 32");
                let stream: &[u8; 8] = stream.try_into().expect("stream.len() is equal to 8");

                // Widen the 9 stored little-endian bytes back to a full u128.
                let mut word_pos_buf = [0u8; 16];
                word_pos_buf[..9].copy_from_slice(word_pos_raw);
                let word_pos = u128::from_le_bytes(word_pos_buf);

                let core = ChaChaCore::new_internal(seed, stream);
                let mut res = Self {
                    core: BlockRng::new(core),
                };

                res.set_word_pos(word_pos);
                res
            }
        }
    };
}
341
// Public RNG types: one per `Rounds` marker imported from the crate root
// (`R8`, `R12`, `R20`), all using the `Legacy` variant chosen inside the macro.
impl_chacha_rng!(ChaCha8Rng, R8);
impl_chacha_rng!(ChaCha12Rng, R12);
impl_chacha_rng!(ChaCha20Rng, R20);