1use std::{borrow::Borrow, io::Read};
19
20use serde::{
21 Deserializer,
22 de::{DeserializeSeed, EnumAccess, Unexpected, VariantAccess, Visitor},
23};
24
25use super::{Config, DESERIALIZE_ANY, SchemaAwareDeserializer, identifier::IdentifierDeserializer};
26use crate::{
27 Error, Schema,
28 error::Details,
29 schema::{EnumSchema, UnionSchema},
30 util::zag_i32,
31};
32
/// Deserializes an Avro `enum` value as a Rust enum variant.
///
/// The encoded value is a zig-zag `int` index into `symbols`; the symbol at
/// that position is handed to serde as the variant identifier.
pub struct PlainEnumDeserializer<'s, 'r, R>
where
    R: Read,
{
    /// Symbol names of the enum schema, in declaration order.
    symbols: &'s [String],
    /// Source of the encoded bytes.
    reader: &'r mut R,
}
38
39impl<'s, 'r, R: Read> PlainEnumDeserializer<'s, 'r, R> {
40 pub fn new(reader: &'r mut R, schema: &'s EnumSchema) -> Self {
41 Self {
42 symbols: &schema.symbols,
43 reader,
44 }
45 }
46}
47
48impl<'de, 's, 'r, R: Read> EnumAccess<'de> for PlainEnumDeserializer<'s, 'r, R> {
49 type Error = Error;
50 type Variant = Self;
51
52 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
53 where
54 V: DeserializeSeed<'de>,
55 {
56 let index = zag_i32(self.reader)?;
57 let index = usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?;
58 let symbol = self.symbols.get(index).ok_or(Details::EnumSymbolIndex {
59 index,
60 num_variants: self.symbols.len(),
61 })?;
62 Ok((
63 seed.deserialize(IdentifierDeserializer::string(symbol))?,
64 self,
65 ))
66 }
67}
68
69impl<'de, 's, 'r, R: Read> VariantAccess<'de> for PlainEnumDeserializer<'s, 'r, R> {
70 type Error = Error;
71
72 fn unit_variant(self) -> Result<(), Self::Error> {
73 Ok(())
74 }
75
76 fn newtype_variant_seed<T>(self, _: T) -> Result<T::Value, Self::Error>
77 where
78 T: DeserializeSeed<'de>,
79 {
80 let unexp = Unexpected::UnitVariant;
81 Err(serde::de::Error::invalid_type(unexp, &"newtype variant"))
82 }
83
84 fn tuple_variant<V>(self, _: usize, _: V) -> Result<V::Value, Self::Error>
85 where
86 V: Visitor<'de>,
87 {
88 let unexp = Unexpected::UnitVariant;
89 Err(serde::de::Error::invalid_type(unexp, &"tuple variant"))
90 }
91
92 fn struct_variant<V>(self, _: &'static [&'static str], _: V) -> Result<V::Value, Self::Error>
93 where
94 V: Visitor<'de>,
95 {
96 let unexp = Unexpected::UnitVariant;
97 Err(serde::de::Error::invalid_type(unexp, &"struct variant"))
98 }
99}
100
101pub struct UnionEnumDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
102 reader: &'r mut R,
103 variants: &'s [Schema],
104 config: Config<'s, S>,
105}
106
107impl<'s, 'r, R: Read, S: Borrow<Schema>> UnionEnumDeserializer<'s, 'r, R, S> {
108 pub fn new(reader: &'r mut R, schema: &'s UnionSchema, config: Config<'s, S>) -> Self {
109 Self {
110 reader,
111 variants: schema.variants(),
112 config,
113 }
114 }
115}
116
117impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> EnumAccess<'de>
118 for UnionEnumDeserializer<'s, 'r, R, S>
119{
120 type Error = Error;
121 type Variant = UnionVariantAccess<'s, 'r, R, S>;
122
123 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
124 where
125 V: DeserializeSeed<'de>,
126 {
127 let index = zag_i32(self.reader)?;
128 let index = usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?;
129 let schema = self.variants.get(index).ok_or(Details::GetUnionVariant {
130 index: index as i64,
131 num_variants: self.variants.len(),
132 })?;
133
134 Ok((
135 seed.deserialize(IdentifierDeserializer::index(index as u32))?,
136 UnionVariantAccess::new(schema, self.reader, self.config)?,
137 ))
138 }
139}
140
141pub struct UnionVariantAccess<'s, 'r, R: Read, S: Borrow<Schema>> {
142 schema: &'s Schema,
143 reader: &'r mut R,
144 config: Config<'s, S>,
145}
146
147impl<'s, 'r, R: Read, S: Borrow<Schema>> UnionVariantAccess<'s, 'r, R, S> {
148 pub fn new(
149 schema: &'s Schema,
150 reader: &'r mut R,
151 config: Config<'s, S>,
152 ) -> Result<Self, Error> {
153 let schema = if let Schema::Ref { name } = schema {
154 config.get_schema(name)?
155 } else {
156 schema
157 };
158 Ok(Self {
159 schema,
160 reader,
161 config,
162 })
163 }
164
165 fn error(&self, ty: &'static str, error: impl Into<String>) -> Error {
166 Error::new(Details::DeserializeSchemaAware {
167 value_type: ty,
168 value: error.into(),
169 schema: self.schema.clone(),
170 })
171 }
172}
173
174impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> VariantAccess<'de>
175 for UnionVariantAccess<'s, 'r, R, S>
176{
177 type Error = Error;
178
179 fn unit_variant(self) -> Result<(), Self::Error> {
180 match self.schema {
181 Schema::Null => Ok(()),
182 Schema::Record(record) if record.fields.is_empty() => Ok(()),
183 _ => Err(self.error(
184 "unit variant",
185 "Expected Schema::Null | Schema::Record(fields.len() == 0)",
186 )),
187 }
188 }
189
190 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
191 where
192 T: DeserializeSeed<'de>,
193 {
194 match self.schema {
195 Schema::Record(record)
196 if record.fields.len() == 1
197 && record
198 .attributes
199 .get("org.apache.avro.rust.union_of_records")
200 == Some(&serde_json::Value::Bool(true)) =>
201 {
202 seed.deserialize(SchemaAwareDeserializer::new(
203 self.reader,
204 &record.fields[0].schema,
205 self.config,
206 )?)
207 }
208 _ => seed.deserialize(SchemaAwareDeserializer::new(
209 self.reader,
210 self.schema,
211 self.config,
212 )?),
213 }
214 }
215
216 fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
217 where
218 V: Visitor<'de>,
219 {
220 SchemaAwareDeserializer::new(self.reader, self.schema, self.config)?
221 .deserialize_tuple(len, visitor)
222 }
223
224 fn struct_variant<V>(
225 self,
226 fields: &'static [&'static str],
227 visitor: V,
228 ) -> Result<V::Value, Self::Error>
229 where
230 V: Visitor<'de>,
231 {
232 SchemaAwareDeserializer::new(self.reader, self.schema, self.config)?.deserialize_struct(
233 DESERIALIZE_ANY,
234 fields,
235 visitor,
236 )
237 }
238}