chacha20/
rng.rs

1#![allow(clippy::cast_possible_truncation, reason = "needs triage")]
2#![allow(clippy::undocumented_unsafe_blocks, reason = "TODO")]
3
4use core::fmt;
5
6use rand_core::{
7    Infallible, SeedableRng, TryCryptoRng, TryRng,
8    block::{BlockRng, Generator},
9};
10
11#[cfg(feature = "zeroize")]
12use zeroize::{Zeroize, ZeroizeOnDrop};
13
14use crate::{
15    ChaChaCore, R8, R12, R20, Rounds, backends,
16    variants::{Legacy, Variant},
17};
18
19use cfg_if::cfg_if;
20
/// Seed value (32 bytes) used to initialize ChaCha-based RNGs.
pub type Seed = [u8; 32];

/// Serialized RNG state: a 32-byte seed, an 8-byte stream ID and a 9-byte
/// word position, concatenated in that order.
pub type SerializedRngState = [u8; 49];

/// Number of 32-bit words in one ChaCha block (fixed by the algorithm).
pub(crate) const BLOCK_WORDS: u8 = 16;

/// Number of ChaCha blocks produced by the RNG core per refill.
const BUF_BLOCKS: u8 = 4;

/// Length, in 32-bit words, of the buffer the RNG core fills on each refill.
const BUFFER_SIZE: usize = BLOCK_WORDS as usize * BUF_BLOCKS as usize;
34
35impl<R: Rounds, V: Variant> SeedableRng for ChaChaCore<R, V> {
36    type Seed = Seed;
37
38    #[inline]
39    fn from_seed(seed: Self::Seed) -> Self {
40        ChaChaCore::new_internal(&seed, &[0u8; 8])
41    }
42}
43
impl<R: Rounds, V: Variant> Generator for ChaChaCore<R, V> {
    /// One refill of the block RNG: `BUFFER_SIZE` (= 4 blocks * 16) 32-bit words.
    type Output = [u32; BUFFER_SIZE];

    /// Fills `buffer` with 4 keystream blocks.
    ///
    /// Generates 4 blocks in parallel with avx2 & neon, but merely fills
    /// 4 blocks with sse2 & soft.
    ///
    /// Backend selection order:
    /// - `chacha20_backend = "soft"` forces the portable backend.
    /// - On x86/x86_64, an explicit `avx2`/`avx512` or `sse2` backend cfg selects
    ///   that backend unconditionally; otherwise the choice is made at runtime
    ///   from `self.tokens`, falling back to the portable backend.
    /// - On aarch64 with NEON enabled at compile time, NEON is used.
    /// - Everything else uses the portable backend.
    fn generate(&mut self, buffer: &mut [u32; BUFFER_SIZE]) {
        cfg_if! {
            if #[cfg(chacha20_backend = "soft")] {
                backends::soft::Backend(self).gen_ks_blocks(buffer);
            } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
                cfg_if! {
                    // AVX-512 doesn't support RNG, so use AVX-2 instead
                    if #[cfg(any(chacha20_backend = "avx2", chacha20_backend = "avx512"))] {
                        // SAFETY: the `chacha20_backend` cfg is an explicit build-time
                        // opt-in asserting the target CPU supports AVX2; no runtime
                        // check is performed on this path.
                        // NOTE(review): confirm this opt-in contract is documented
                        // for users of the cfg.
                        unsafe {
                            backends::avx2::rng_inner::<R, V>(self, buffer);
                        }
                    } else if #[cfg(chacha20_backend = "sse2")] {
                        // SAFETY: the `chacha20_backend = "sse2"` cfg is an explicit
                        // build-time opt-in asserting SSE2 availability.
                        unsafe {
                            backends::sse2::rng_inner::<R, V>(self, buffer);
                        }
                    } else {
                        // The token tuple gains an AVX-512 entry when that backend is
                        // compiled in; it is unused here because the RNG path has no
                        // AVX-512 implementation.
                        #[cfg(chacha20_avx512)]
                        let (_avx512_token, avx2_token, sse2_token) = self.tokens;
                        #[cfg(not(chacha20_avx512))]
                        let (avx2_token, sse2_token) = self.tokens;

                        if avx2_token.get() {
                            // SAFETY: `avx2_token.get()` returned true, presumably
                            // meaning runtime CPU feature detection reported AVX2
                            // (TODO confirm against the token type's docs).
                            unsafe {
                                backends::avx2::rng_inner::<R, V>(self, buffer);
                            }
                        } else if sse2_token.get() {
                            // SAFETY: `sse2_token.get()` returned true, presumably
                            // meaning runtime CPU feature detection reported SSE2
                            // (TODO confirm against the token type's docs).
                            unsafe {
                                backends::sse2::rng_inner::<R, V>(self, buffer);
                            }
                        } else {
                            backends::soft::Backend(self).gen_ks_blocks(buffer);
                        }
                    }
                }
            } else if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] {
                // SAFETY: we have used conditional compilation to ensure NEON is available
                unsafe {
                    backends::neon::rng_inner::<R, V>(self, buffer);
                }
            } else {
                backends::soft::Backend(self).gen_ks_blocks(buffer);
            }
        }
    }

    // `Drop` impl of `BlockRng` calls this method and passes reference to
    // its internal buffer in `output`. So we zeroize its contents here.
    #[cfg(feature = "zeroize")]
    fn drop(&mut self, output: &mut Self::Output) {
        output.zeroize();
    }
}
101
macro_rules! impl_chacha_rng {
    ($Rng:ident, $rounds:ident) => {
        /// A cryptographically secure random number generator that uses the ChaCha stream cipher.
        ///
        /// See the [crate docs][crate] for more information about the underlying stream cipher.
        ///
        /// This RNG implementation uses a 64-bit counter and 64-bit stream identifier (a.k.a nonce).
        /// A 64-bit counter over 64-byte (16 word) blocks allows 1 ZiB of output before cycling,
        /// and the stream identifier allows 2<sup>64</sup> unique streams of output per seed.
        /// Both counter and stream are initialized to zero but may be set via the [`set_word_pos`]
        /// and [`set_stream`] methods.
        ///
        /// [`set_word_pos`]: Self::set_word_pos
        /// [`set_stream`]: Self::set_stream
        ///
        /// # Example
        ///
        /// ```rust
        #[doc = concat!("use chacha20::", stringify!($Rng), ";")]
        /// use rand_core::{SeedableRng, Rng};
        ///
        /// let seed = [42u8; 32];
        #[doc = concat!("let mut rng = ", stringify!($Rng), "::from_seed(seed);")]
        ///
        /// let random_u32 = rng.next_u32();
        /// let random_u64 = rng.next_u64();
        ///
        /// let mut random_bytes = [0u8; 3];
        /// rng.fill_bytes(&mut random_bytes);
        /// ```
        ///
        /// See the [`rand`](https://docs.rs/rand/) crate for more advanced RNG functionality.
        pub struct $Rng {
            core: BlockRng<ChaChaCore<$rounds, Legacy>>,
        }

        impl SeedableRng for $Rng {
            type Seed = Seed;

            #[inline]
            fn from_seed(seed: Self::Seed) -> Self {
                let core = ChaChaCore::new_internal(&seed, &[0u8; 8]);
                Self {
                    core: BlockRng::new(core),
                }
            }
        }

        impl TryRng for $Rng {
            type Error = Infallible;

            #[inline]
            fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
                Ok(self.core.next_word())
            }
            #[inline]
            fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
                Ok(self.core.next_u64_from_u32())
            }
            #[inline]
            fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Self::Error> {
                self.core.fill_bytes(dest);
                Ok(())
            }
        }

        impl TryCryptoRng for $Rng {}

        #[cfg(feature = "zeroize")]
        impl ZeroizeOnDrop for $Rng {}

        // We use custom implementation of `PartialEq` because RNG states
        // may buffer different parts of the same keystream, while keeping
        // buffer cursor pointing towards the same keystream point.
        impl PartialEq<$Rng> for $Rng {
            fn eq(&self, rhs: &$Rng) -> bool {
                (self.get_seed() == rhs.get_seed())
                    && (self.get_stream() == rhs.get_stream())
                    && (self.get_word_pos() == rhs.get_word_pos())
            }
        }

        impl Eq for $Rng {}

        // Custom Debug implementation that does not expose the internal state
        impl fmt::Debug for $Rng {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                write!(f, concat!(stringify!($Rng), " {{ ... }}"))
            }
        }

        impl $Rng {
            /// Get the offset from the start of the stream, in 32-bit words.
            ///
            /// Since the generated blocks are 16 words (2<sup>4</sup>) long and the
            /// counter is 64-bits, the offset is a 68-bit number. Sub-word offsets are
            /// not supported, hence the result can simply be multiplied by 4 to get a
            /// byte-offset.
            #[inline]
            #[must_use]
            pub fn get_word_pos(&self) -> u128 {
                let mut block_counter = (u64::from(self.core.core.state[13]) << 32)
                    | u64::from(self.core.core.state[12]);
                if self.core.word_offset() != 0 {
                    // The stored counter already points past the buffered batch of
                    // `BUF_BLOCKS` blocks; step back to the start of that batch.
                    block_counter = block_counter.wrapping_sub(u64::from(BUF_BLOCKS));
                }
                let word_pos = u128::from(block_counter) * u128::from(BLOCK_WORDS)
                    + self.core.word_offset() as u128;
                // eliminate bits above the 68th bit
                word_pos & ((1 << 68) - 1)
            }

            /// Set the offset from the start of the stream, in 32-bit words.
            ///
            /// **This value will be erased when calling `set_stream()`,
            /// so call `set_stream()` before calling `set_word_pos()`**
            /// if you intend on using both of them together.
            ///
            /// As with `get_word_pos`, we use a 68-bit number. Since the generator
            /// simply cycles at the end of its period (1 ZiB), we ignore the upper
            /// 60 bits.
            #[inline]
            pub fn set_word_pos(&mut self, word_offset: u128) {
                let index = (word_offset % u128::from(BLOCK_WORDS)) as usize;
                let counter = word_offset / u128::from(BLOCK_WORDS);
                // Only the low 64 bits of the block counter are stored (state words
                // 12 and 13); any higher bits are intentionally dropped.
                self.core.core.state[12] = counter as u32;
                self.core.core.state[13] = (counter >> 32) as u32;
                // Discard the buffer and skip `index` words into the first block.
                self.core.reset_and_skip(index);
            }

            /// Sets the block pos and resets the RNG's index.
            ///
            /// **This value will be erased when calling `set_stream()`,
            /// so call `set_stream()` before calling `set_block_pos()`**
            /// if you intend on using both of them together.
            ///
            /// The word pos will be equal to `block_pos * 16 words per block`.
            #[inline]
            #[allow(unused)]
            pub fn set_block_pos(&mut self, block_pos: u64) {
                self.core.reset_and_skip(0);
                self.core.core.set_block_pos(block_pos);
            }

            /// Get the block pos.
            ///
            /// Returns the index of the block that contains the next word to be
            /// produced. When part of a buffered batch has been consumed, this applies
            /// the same adjustment as `get_word_pos`. The counter wraps around at the
            /// end of its 64-bit period instead of overflowing.
            #[inline]
            #[allow(unused)]
            #[must_use]
            pub fn get_block_pos(&self) -> u64 {
                let counter = self.core.core.get_block_pos();
                let offset = self.core.word_offset();
                if offset != 0 {
                    // The core counter has already advanced past the whole buffered
                    // batch of `BUF_BLOCKS` blocks: step back to the batch start and
                    // add the number of whole blocks consumed within it. Wrapping
                    // arithmetic is required because the counter may have wrapped past
                    // `u64::MAX` when the buffer was refilled; plain `-` would panic in
                    // debug builds.
                    counter
                        .wrapping_sub(u64::from(BUF_BLOCKS))
                        .wrapping_add(offset as u64 / u64::from(BLOCK_WORDS))
                } else {
                    counter
                }
            }

            /// Set the stream ID and reset the `word_pos` to 0.
            #[inline]
            pub fn set_stream(&mut self, stream: u64) {
                self.core.core.state[14] = stream as u32;
                self.core.core.state[15] = (stream >> 32) as u32;
                self.set_block_pos(0);
            }

            /// Get the stream number (nonce).
            #[inline]
            #[must_use]
            pub fn get_stream(&self) -> u64 {
                let mut result = [0u8; 8];
                result[..4].copy_from_slice(&self.core.core.state[14].to_le_bytes());
                result[4..].copy_from_slice(&self.core.core.state[15].to_le_bytes());
                u64::from_le_bytes(result)
            }

            /// Get the RNG seed.
            #[inline]
            #[must_use]
            pub fn get_seed(&self) -> [u8; 32] {
                let seed = &self.core.core.state[4..12];
                let mut result = [0u8; 32];
                for (src, dst) in seed.iter().zip(result.chunks_exact_mut(4)) {
                    dst.copy_from_slice(&src.to_le_bytes())
                }
                result
            }

            /// Serialize RNG state.
            ///
            /// Layout: bytes `0..32` hold the seed, `32..40` the stream ID
            /// (little-endian), and `40..49` the low 9 bytes of the little-endian
            /// word position.
            ///
            /// # Warning
            /// Leaking serialized RNG state to an attacker defeats security properties
            /// provided by the RNG.
            #[inline]
            #[must_use]
            pub fn serialize_state(&self) -> SerializedRngState {
                let seed = self.get_seed();
                let stream = self.get_stream().to_le_bytes();
                let word_pos = self.get_word_pos().to_le_bytes();

                let mut res: SerializedRngState = [0u8; 49];
                let (seed_dst, res_rem) = res.split_at_mut(32);
                let (stream_dst, word_pos_dst) = res_rem.split_at_mut(8);

                seed_dst.copy_from_slice(&seed);
                stream_dst.copy_from_slice(&stream);
                word_pos_dst.copy_from_slice(&word_pos[..9]);

                // `get_word_pos` masks to 68 bits, so the upper 7 bytes are zero.
                debug_assert_eq!(&word_pos[9..], &[0u8; 7]);

                res
            }

            /// Deserialize RNG state produced by `serialize_state`.
            #[inline]
            #[must_use]
            pub fn deserialize_state(state: &SerializedRngState) -> Self {
                let (seed, state_rem) = state.split_at(32);
                let (stream, word_pos_raw) = state_rem.split_at(8);

                let seed: &[u8; 32] = seed.try_into().expect("seed.len() is equal to 32");
                let stream: &[u8; 8] = stream.try_into().expect("stream.len() is equal to 8");

                // Note that we use only 68 bits from `word_pos_raw`, i.e. 4 remaining bits
                // get ignored and should be equal to zero in practice.
                let mut word_pos_buf = [0u8; 16];
                word_pos_buf[..9].copy_from_slice(word_pos_raw);
                let word_pos = u128::from_le_bytes(word_pos_buf);

                let core = ChaChaCore::new_internal(seed, stream);
                let mut res = Self {
                    core: BlockRng::new(core),
                };

                res.set_word_pos(word_pos);
                res
            }
        }
    };
}
338        }
339    };
340}
341
342impl_chacha_rng!(ChaCha8Rng, R8);
343impl_chacha_rng!(ChaCha12Rng, R12);
344impl_chacha_rng!(ChaCha20Rng, R20);