1use crate::{AvroResult, Error};
19use num_bigint::{BigInt, Sign};
20use serde::{de::SeqAccess, Deserialize, Serialize, Serializer};
21
/// A decimal value held as an arbitrary-precision integer together with the
/// byte width of its encoding.
///
/// The integer is decoded from, and re-encoded to, big-endian two's-complement
/// bytes. `len` records how many bytes the source encoding had, so
/// [`Decimal::to_vec`] re-emits a value of the same width (sign-extended).
/// No scale is stored in this type.
#[derive(Debug, Clone, Eq)]
pub struct Decimal {
    /// Integer value decoded from the two's-complement bytes.
    value: BigInt,
    /// Byte width of the encoding; used as the target width when re-encoding.
    len: usize,
}
27
28impl Serialize for Decimal {
29 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
30 where
31 S: Serializer,
32 {
33 match self.to_vec() {
34 Ok(ref bytes) => serializer.serialize_bytes(bytes),
35 Err(e) => Err(serde::ser::Error::custom(e)),
36 }
37 }
38}
39impl<'de> Deserialize<'de> for Decimal {
40 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
41 where
42 D: serde::Deserializer<'de>,
43 {
44 struct DecimalVisitor;
45 impl<'de> serde::de::Visitor<'de> for DecimalVisitor {
46 type Value = Decimal;
47
48 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
49 formatter.write_str("a byte slice or seq of bytes")
50 }
51
52 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
53 where
54 E: serde::de::Error,
55 {
56 Ok(Decimal::from(v))
57 }
58 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
59 where
60 A: SeqAccess<'de>,
61 {
62 let mut bytes = Vec::new();
63 while let Some(value) = seq.next_element::<u8>()? {
64 bytes.push(value);
65 }
66
67 Ok(Decimal::from(bytes))
68 }
69 }
70 deserializer.deserialize_bytes(DecimalVisitor)
71 }
72}
73
74impl PartialEq for Decimal {
77 fn eq(&self, other: &Self) -> bool {
78 self.value == other.value
79 }
80}
81
82impl Decimal {
83 pub(crate) fn len(&self) -> usize {
84 self.len
85 }
86
87 pub(crate) fn to_vec(&self) -> AvroResult<Vec<u8>> {
88 self.to_sign_extended_bytes_with_len(self.len)
89 }
90
91 pub(crate) fn to_sign_extended_bytes_with_len(&self, len: usize) -> AvroResult<Vec<u8>> {
92 let sign_byte = 0xFF * u8::from(self.value.sign() == Sign::Minus);
93 let mut decimal_bytes = vec![sign_byte; len];
94 let raw_bytes = self.value.to_signed_bytes_be();
95 let num_raw_bytes = raw_bytes.len();
96 let start_byte_index = len.checked_sub(num_raw_bytes).ok_or(Error::SignExtend {
97 requested: len,
98 needed: num_raw_bytes,
99 })?;
100 decimal_bytes[start_byte_index..].copy_from_slice(&raw_bytes);
101 Ok(decimal_bytes)
102 }
103}
104
105impl From<Decimal> for BigInt {
106 fn from(decimal: Decimal) -> Self {
107 decimal.value
108 }
109}
110
111impl std::convert::TryFrom<&Decimal> for Vec<u8> {
121 type Error = Error;
122
123 fn try_from(decimal: &Decimal) -> Result<Self, Self::Error> {
124 decimal.to_vec()
125 }
126}
127
128impl std::convert::TryFrom<Decimal> for Vec<u8> {
138 type Error = Error;
139
140 fn try_from(decimal: Decimal) -> Result<Self, Self::Error> {
141 decimal.to_vec()
142 }
143}
144
145impl<T: AsRef<[u8]>> From<T> for Decimal {
146 fn from(bytes: T) -> Self {
147 let bytes_ref = bytes.as_ref();
148 Self {
149 value: BigInt::from_signed_bytes_be(bytes_ref),
150 len: bytes_ref.len(),
151 }
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use apache_avro_test_helper::TestResult;
    use pretty_assertions::assert_eq;

    /// Bytes -> `Decimal` -> bytes through the borrowed `TryFrom` impl is lossless.
    #[test]
    fn test_decimal_from_bytes_from_ref_decimal() -> TestResult {
        let input: Vec<u8> = vec![1, 24];
        let decimal = Decimal::from(&input);

        assert_eq!(<Vec<u8>>::try_from(&decimal)?, input);

        Ok(())
    }

    /// Bytes -> `Decimal` -> bytes through the owned `TryFrom` impl is lossless.
    #[test]
    fn test_decimal_from_bytes_from_owned_decimal() -> TestResult {
        let input: Vec<u8> = vec![1, 24];
        let decimal = Decimal::from(&input);

        assert_eq!(<Vec<u8>>::try_from(decimal)?, input);

        Ok(())
    }

    /// A `Decimal` survives a serde_json serialize/deserialize round trip.
    #[test]
    fn avro_3949_decimal_serde() -> TestResult {
        let original = Decimal::from(&[1u8, 2, 3]);

        let json = serde_json::to_string(&original)?;
        let restored: Decimal = serde_json::from_str(&json)?;
        std::assert_eq!(original, restored);

        Ok(())
    }
}