Skip to main content

apache_avro/serde/deser_schema/
enums.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::{borrow::Borrow, io::Read};
19
20use serde::{
21    Deserializer,
22    de::{DeserializeSeed, EnumAccess, Unexpected, VariantAccess, Visitor},
23};
24
25use super::{Config, DESERIALIZE_ANY, SchemaAwareDeserializer, identifier::IdentifierDeserializer};
26use crate::{
27    Error, Schema,
28    error::Details,
29    schema::{EnumSchema, UnionSchema},
30    util::zag_i32,
31};
32
33/// Deserializer for plain enums.
34pub struct PlainEnumDeserializer<'s, 'r, R: Read> {
35    reader: &'r mut R,
36    symbols: &'s [String],
37}
38
39impl<'s, 'r, R: Read> PlainEnumDeserializer<'s, 'r, R> {
40    pub fn new(reader: &'r mut R, schema: &'s EnumSchema) -> Self {
41        Self {
42            symbols: &schema.symbols,
43            reader,
44        }
45    }
46}
47
48impl<'de, 's, 'r, R: Read> EnumAccess<'de> for PlainEnumDeserializer<'s, 'r, R> {
49    type Error = Error;
50    type Variant = Self;
51
52    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
53    where
54        V: DeserializeSeed<'de>,
55    {
56        let index = zag_i32(self.reader)?;
57        let index = usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?;
58        let symbol = self.symbols.get(index).ok_or(Details::EnumSymbolIndex {
59            index,
60            num_variants: self.symbols.len(),
61        })?;
62        Ok((
63            seed.deserialize(IdentifierDeserializer::string(symbol))?,
64            self,
65        ))
66    }
67}
68
69impl<'de, 's, 'r, R: Read> VariantAccess<'de> for PlainEnumDeserializer<'s, 'r, R> {
70    type Error = Error;
71
72    fn unit_variant(self) -> Result<(), Self::Error> {
73        Ok(())
74    }
75
76    fn newtype_variant_seed<T>(self, _: T) -> Result<T::Value, Self::Error>
77    where
78        T: DeserializeSeed<'de>,
79    {
80        let unexp = Unexpected::UnitVariant;
81        Err(serde::de::Error::invalid_type(unexp, &"newtype variant"))
82    }
83
84    fn tuple_variant<V>(self, _: usize, _: V) -> Result<V::Value, Self::Error>
85    where
86        V: Visitor<'de>,
87    {
88        let unexp = Unexpected::UnitVariant;
89        Err(serde::de::Error::invalid_type(unexp, &"tuple variant"))
90    }
91
92    fn struct_variant<V>(self, _: &'static [&'static str], _: V) -> Result<V::Value, Self::Error>
93    where
94        V: Visitor<'de>,
95    {
96        let unexp = Unexpected::UnitVariant;
97        Err(serde::de::Error::invalid_type(unexp, &"struct variant"))
98    }
99}
100
101pub struct UnionEnumDeserializer<'s, 'r, R: Read, S: Borrow<Schema>> {
102    reader: &'r mut R,
103    variants: &'s [Schema],
104    config: Config<'s, S>,
105}
106
107impl<'s, 'r, R: Read, S: Borrow<Schema>> UnionEnumDeserializer<'s, 'r, R, S> {
108    pub fn new(reader: &'r mut R, schema: &'s UnionSchema, config: Config<'s, S>) -> Self {
109        Self {
110            reader,
111            variants: schema.variants(),
112            config,
113        }
114    }
115}
116
117impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> EnumAccess<'de>
118    for UnionEnumDeserializer<'s, 'r, R, S>
119{
120    type Error = Error;
121    type Variant = UnionVariantAccess<'s, 'r, R, S>;
122
123    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
124    where
125        V: DeserializeSeed<'de>,
126    {
127        let index = zag_i32(self.reader)?;
128        let index = usize::try_from(index).map_err(|e| Details::ConvertI32ToUsize(e, index))?;
129        let schema = self.variants.get(index).ok_or(Details::GetUnionVariant {
130            index: index as i64,
131            num_variants: self.variants.len(),
132        })?;
133
134        Ok((
135            seed.deserialize(IdentifierDeserializer::index(index as u32))?,
136            UnionVariantAccess::new(schema, self.reader, self.config)?,
137        ))
138    }
139}
140
141pub struct UnionVariantAccess<'s, 'r, R: Read, S: Borrow<Schema>> {
142    schema: &'s Schema,
143    reader: &'r mut R,
144    config: Config<'s, S>,
145}
146
147impl<'s, 'r, R: Read, S: Borrow<Schema>> UnionVariantAccess<'s, 'r, R, S> {
148    pub fn new(
149        schema: &'s Schema,
150        reader: &'r mut R,
151        config: Config<'s, S>,
152    ) -> Result<Self, Error> {
153        let schema = if let Schema::Ref { name } = schema {
154            config.get_schema(name)?
155        } else {
156            schema
157        };
158        Ok(Self {
159            schema,
160            reader,
161            config,
162        })
163    }
164
165    fn error(&self, ty: &'static str, error: impl Into<String>) -> Error {
166        Error::new(Details::DeserializeSchemaAware {
167            value_type: ty,
168            value: error.into(),
169            schema: self.schema.clone(),
170        })
171    }
172}
173
174impl<'de, 's, 'r, R: Read, S: Borrow<Schema>> VariantAccess<'de>
175    for UnionVariantAccess<'s, 'r, R, S>
176{
177    type Error = Error;
178
179    fn unit_variant(self) -> Result<(), Self::Error> {
180        match self.schema {
181            Schema::Null => Ok(()),
182            Schema::Record(record) if record.fields.is_empty() => Ok(()),
183            _ => Err(self.error(
184                "unit variant",
185                "Expected Schema::Null | Schema::Record(fields.len() == 0)",
186            )),
187        }
188    }
189
190    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
191    where
192        T: DeserializeSeed<'de>,
193    {
194        match self.schema {
195            Schema::Record(record)
196                if record.fields.len() == 1
197                    && record
198                        .attributes
199                        .get("org.apache.avro.rust.union_of_records")
200                        == Some(&serde_json::Value::Bool(true)) =>
201            {
202                seed.deserialize(SchemaAwareDeserializer::new(
203                    self.reader,
204                    &record.fields[0].schema,
205                    self.config,
206                )?)
207            }
208            _ => seed.deserialize(SchemaAwareDeserializer::new(
209                self.reader,
210                self.schema,
211                self.config,
212            )?),
213        }
214    }
215
216    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
217    where
218        V: Visitor<'de>,
219    {
220        SchemaAwareDeserializer::new(self.reader, self.schema, self.config)?
221            .deserialize_tuple(len, visitor)
222    }
223
224    fn struct_variant<V>(
225        self,
226        fields: &'static [&'static str],
227        visitor: V,
228    ) -> Result<V::Value, Self::Error>
229    where
230        V: Visitor<'de>,
231    {
232        SchemaAwareDeserializer::new(self.reader, self.schema, self.config)?.deserialize_struct(
233            DESERIALIZE_ANY,
234            fields,
235            visitor,
236        )
237    }
238}