apache_avro/
decimal.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{AvroResult, Error};
19use num_bigint::{BigInt, Sign};
20use serde::{de::SeqAccess, Deserialize, Serialize, Serializer};
21
22#[derive(Debug, Clone, Eq)]
23pub struct Decimal {
24    value: BigInt,
25    len: usize,
26}
27
28impl Serialize for Decimal {
29    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
30    where
31        S: Serializer,
32    {
33        match self.to_vec() {
34            Ok(ref bytes) => serializer.serialize_bytes(bytes),
35            Err(e) => Err(serde::ser::Error::custom(e)),
36        }
37    }
38}
39impl<'de> Deserialize<'de> for Decimal {
40    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41    where
42        D: serde::Deserializer<'de>,
43    {
44        struct DecimalVisitor;
45        impl<'de> serde::de::Visitor<'de> for DecimalVisitor {
46            type Value = Decimal;
47
48            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
49                formatter.write_str("a byte slice or seq of bytes")
50            }
51
52            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
53            where
54                E: serde::de::Error,
55            {
56                Ok(Decimal::from(v))
57            }
58            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
59            where
60                A: SeqAccess<'de>,
61            {
62                let mut bytes = Vec::new();
63                while let Some(value) = seq.next_element::<u8>()? {
64                    bytes.push(value);
65                }
66
67                Ok(Decimal::from(bytes))
68            }
69        }
70        deserializer.deserialize_bytes(DecimalVisitor)
71    }
72}
73
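// Equality deliberately compares only the numeric value, not the byte length. The same
// `BigInt` can come from encodings of different lengths: sign extension means `[0x01]` and
// `[0x00, 0x01]` both decode to 1, and `[0xFF]` and `[0xFF, 0xFF]` both decode to -1.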
76impl PartialEq for Decimal {
77    fn eq(&self, other: &Self) -> bool {
78        self.value == other.value
79    }
80}
81
82impl Decimal {
83    pub(crate) fn len(&self) -> usize {
84        self.len
85    }
86
87    pub(crate) fn to_vec(&self) -> AvroResult<Vec<u8>> {
88        self.to_sign_extended_bytes_with_len(self.len)
89    }
90
91    pub(crate) fn to_sign_extended_bytes_with_len(&self, len: usize) -> AvroResult<Vec<u8>> {
92        let sign_byte = 0xFF * u8::from(self.value.sign() == Sign::Minus);
93        let mut decimal_bytes = vec![sign_byte; len];
94        let raw_bytes = self.value.to_signed_bytes_be();
95        let num_raw_bytes = raw_bytes.len();
96        let start_byte_index = len.checked_sub(num_raw_bytes).ok_or(Error::SignExtend {
97            requested: len,
98            needed: num_raw_bytes,
99        })?;
100        decimal_bytes[start_byte_index..].copy_from_slice(&raw_bytes);
101        Ok(decimal_bytes)
102    }
103}
104
105impl From<Decimal> for BigInt {
106    fn from(decimal: Decimal) -> Self {
107        decimal.value
108    }
109}
110
111/// Gets the internal byte array representation of a referenced decimal.
112/// Usage:
113/// ```
114/// use apache_avro::Decimal;
115/// use std::convert::TryFrom;
116///
117/// let decimal = Decimal::from(vec![1, 24]);
118/// let maybe_bytes = <Vec<u8>>::try_from(&decimal);
119/// ```
120impl std::convert::TryFrom<&Decimal> for Vec<u8> {
121    type Error = Error;
122
123    fn try_from(decimal: &Decimal) -> Result<Self, Self::Error> {
124        decimal.to_vec()
125    }
126}
127
128/// Gets the internal byte array representation of an owned decimal.
129/// Usage:
130/// ```
131/// use apache_avro::Decimal;
132/// use std::convert::TryFrom;
133///
134/// let decimal = Decimal::from(vec![1, 24]);
135/// let maybe_bytes = <Vec<u8>>::try_from(decimal);
136/// ```
137impl std::convert::TryFrom<Decimal> for Vec<u8> {
138    type Error = Error;
139
140    fn try_from(decimal: Decimal) -> Result<Self, Self::Error> {
141        decimal.to_vec()
142    }
143}
144
145impl<T: AsRef<[u8]>> From<T> for Decimal {
146    fn from(bytes: T) -> Self {
147        let bytes_ref = bytes.as_ref();
148        Self {
149            value: BigInt::from_signed_bytes_be(bytes_ref),
150            len: bytes_ref.len(),
151        }
152    }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use apache_avro_test_helper::TestResult;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_decimal_from_bytes_from_ref_decimal() -> TestResult {
        let input = vec![1, 24];
        let d = Decimal::from(&input);

        let output = <Vec<u8>>::try_from(&d)?;
        assert_eq!(output, input);

        Ok(())
    }

    #[test]
    fn test_decimal_from_bytes_from_owned_decimal() -> TestResult {
        let input = vec![1, 24];
        let d = Decimal::from(&input);

        let output = <Vec<u8>>::try_from(d)?;
        assert_eq!(output, input);

        Ok(())
    }

    #[test]
    fn avro_3949_decimal_serde() -> TestResult {
        let decimal = Decimal::from(&[1, 2, 3]);

        let ser = serde_json::to_string(&decimal)?;
        let de = serde_json::from_str(&ser)?;
        std::assert_eq!(decimal, de);

        Ok(())
    }

    #[test]
    fn test_decimal_negative_value_is_sign_extended() -> TestResult {
        // 0xFF decodes to -1 and 0x80 to -128; widening must pad with 0xFF, not 0x00.
        let minus_one = Decimal::from(&[0xFF]);
        let expected: Vec<u8> = vec![0xFF, 0xFF, 0xFF];
        assert_eq!(minus_one.to_sign_extended_bytes_with_len(3)?, expected);

        let minus_128 = Decimal::from(&[0x80]);
        let expected: Vec<u8> = vec![0xFF, 0x80];
        assert_eq!(minus_128.to_sign_extended_bytes_with_len(2)?, expected);

        Ok(())
    }

    #[test]
    fn test_decimal_positive_value_is_zero_extended() -> TestResult {
        let d = Decimal::from(&[0x01, 0x18]);
        let expected: Vec<u8> = vec![0x00, 0x00, 0x01, 0x18];
        assert_eq!(d.to_sign_extended_bytes_with_len(4)?, expected);

        Ok(())
    }

    #[test]
    fn test_decimal_zero_keeps_its_length() -> TestResult {
        let d = Decimal::from(&[0x00, 0x00]);
        assert_eq!(d.len(), 2);
        let expected: Vec<u8> = vec![0x00, 0x00];
        assert_eq!(d.to_vec()?, expected);

        Ok(())
    }

    #[test]
    fn test_decimal_sign_extend_to_shorter_length_fails() {
        // 0x0118 needs two bytes; one is not enough.
        let d = Decimal::from(&[0x01, 0x18]);
        assert!(d.to_sign_extended_bytes_with_len(1).is_err());
    }

    #[test]
    fn test_decimal_equality_ignores_byte_length() {
        assert_eq!(Decimal::from(&[0x00, 0x01]), Decimal::from(&[0x01]));
        assert_eq!(Decimal::from(&[0xFF, 0xFF]), Decimal::from(&[0xFF]));
        assert_ne!(Decimal::from(&[0x01]), Decimal::from(&[0x02]));
    }
}