Skip to main content

apache_avro/serde/ser_schema/
union.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::{borrow::Borrow, io::Write};
19
20use serde::{Serialize, Serializer};
21
22use super::{Config, MapOrRecordSerializer, SchemaAwareSerializer};
23use crate::{
24    Error, Schema,
25    error::Details,
26    schema::{FixedSchema, SchemaKind, UnionSchema},
27    serde::{
28        ser_schema::{
29            block::BlockSerializer,
30            record::RecordSerializer,
31            tuple::{ManyTupleSerializer, TupleSerializer},
32        },
33        with::{BytesType, SER_BYTES_TYPE},
34    },
35    util::{zig_i32, zig_i64},
36};
37
38/// Serializer that finds the right union variant for the type being serialized.
39///
40/// `serialize_*_variant`, `serialize_some`, and `serialize_none` will all return an error as nested
41/// unions are invalid in the Avro specification.
42pub struct UnionSerializer<'s, 'w, W: Write, S: Borrow<Schema>> {
43    writer: &'w mut W,
44    union: &'s UnionSchema,
45    config: Config<'s, S>,
46}
47
48impl<'s, 'w, W: Write, S: Borrow<Schema>> UnionSerializer<'s, 'w, W, S> {
49    pub fn new(writer: &'w mut W, union: &'s UnionSchema, config: Config<'s, S>) -> Self {
50        UnionSerializer {
51            writer,
52            union,
53            config,
54        }
55    }
56
57    fn error(&self, ty: &'static str, error: impl Into<String>) -> Error {
58        Error::new(Details::SerializeValueWithSchema {
59            value_type: ty,
60            value: error.into(),
61            schema: Schema::Union(self.union.clone()),
62        })
63    }
64
65    /// Write an integer to the writer.
66    ///
67    /// This will check that the current schema is [`Schema::Int`] or a logical type based on that.
68    /// This will write the union index.
69    pub(super) fn checked_write_int(
70        &mut self,
71        original_ty: &'static str,
72        v: i32,
73    ) -> Result<usize, Error> {
74        if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Int) {
75            let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
76            bytes_written += zig_i32(v, &mut *self.writer)?;
77            Ok(bytes_written)
78        } else {
79            Err(self.error(
80                original_ty,
81                "Expected Schema::Int | Schema::Date | Schema::TimeMillis in variants",
82            ))
83        }
84    }
85
86    /// Write a long to the writer.
87    ///
88    /// This will check that the current schema is [`Schema::Long`] or a logical type based on that.
89    /// This will write the union index.
90    pub(super) fn checked_write_long(
91        &mut self,
92        original_ty: &'static str,
93        v: i64,
94    ) -> Result<usize, Error> {
95        if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Long) {
96            let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
97            bytes_written += zig_i64(v, &mut *self.writer)?;
98            Ok(bytes_written)
99        } else {
100            Err(self.error(original_ty, "Expected Schema::Long | Schema::TimeMicros | Schema::{,Local}Timestamp{Millis,Micros,Nanos} in variants"))
101        }
102    }
103
104    /// Write bytes to the writer with preceding length header.
105    ///
106    /// This does not check the current schema and does not write the union index.
107    fn write_bytes_with_len(&mut self, bytes: &[u8]) -> Result<usize, Error> {
108        let mut bytes_written = 0;
109        bytes_written += zig_i64(bytes.len() as i64, &mut *self.writer)?;
110        bytes_written += self.write_bytes(bytes)?;
111        Ok(bytes_written)
112    }
113
114    /// Write bytes to the writer.
115    ///
116    /// This does not check the current schema and does not write the union index.
117    fn write_bytes(&mut self, bytes: &[u8]) -> Result<usize, Error> {
118        self.writer.write_all(bytes).map_err(Details::WriteBytes)?;
119        Ok(bytes.len())
120    }
121
122    /// Write an array of `n` bytes to the writer.
123    ///
124    /// This does not check the current schema and does not write the union index.
125    fn write_array<const N: usize>(&mut self, bytes: [u8; N]) -> Result<usize, Error> {
126        self.write_bytes(&bytes)?;
127        Ok(N)
128    }
129}
130
131impl<'s, 'w, W: Write, S: Borrow<Schema>> Serializer for UnionSerializer<'s, 'w, W, S> {
132    type Ok = usize;
133    type Error = Error;
134    type SerializeSeq = BlockSerializer<'s, 'w, W, S>;
135    type SerializeTuple = TupleSerializer<'s, 'w, W, S>;
136    type SerializeTupleStruct = ManyTupleSerializer<'s, 'w, W, S>;
137    type SerializeTupleVariant = ManyTupleSerializer<'s, 'w, W, S>;
138    type SerializeMap = MapOrRecordSerializer<'s, 'w, W, S>;
139    type SerializeStruct = RecordSerializer<'s, 'w, W, S>;
140    type SerializeStructVariant = RecordSerializer<'s, 'w, W, S>;
141
142    fn serialize_bool(mut self, v: bool) -> Result<Self::Ok, Self::Error> {
143        if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Boolean) {
144            let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
145            bytes_written += self.write_array([u8::from(v)])?;
146            Ok(bytes_written)
147        } else {
148            Err(self.error("bool", "Expected Schema::Boolean in variants"))
149        }
150    }
151
152    fn serialize_i8(mut self, v: i8) -> Result<Self::Ok, Self::Error> {
153        self.checked_write_int("i8", i32::from(v))
154    }
155
156    fn serialize_i16(mut self, v: i16) -> Result<Self::Ok, Self::Error> {
157        self.checked_write_int("i16", i32::from(v))
158    }
159
160    fn serialize_i32(mut self, v: i32) -> Result<Self::Ok, Self::Error> {
161        self.checked_write_int("i32", v)
162    }
163
164    fn serialize_i64(mut self, v: i64) -> Result<Self::Ok, Self::Error> {
165        self.checked_write_long("i64", v)
166    }
167
168    fn serialize_i128(mut self, v: i128) -> Result<Self::Ok, Self::Error> {
169        match self.union.find_named_schema("i128", self.config.names)? {
170            Some((index, Schema::Fixed(FixedSchema { size: 16, .. }))) => {
171                let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
172                bytes_written += self.write_array(v.to_le_bytes())?;
173                Ok(bytes_written)
174            }
175            _ => Err(self.error(
176                "i128",
177                r#"Expected Schema::Fixed(name: "i128", size: 16) in variants"#,
178            )),
179        }
180    }
181
182    fn serialize_u8(mut self, v: u8) -> Result<Self::Ok, Self::Error> {
183        self.checked_write_int("u8", i32::from(v))
184    }
185
186    fn serialize_u16(mut self, v: u16) -> Result<Self::Ok, Self::Error> {
187        self.checked_write_int("u16", i32::from(v))
188    }
189
190    fn serialize_u32(mut self, v: u32) -> Result<Self::Ok, Self::Error> {
191        self.checked_write_long("u32", i64::from(v))
192    }
193
194    fn serialize_u64(mut self, v: u64) -> Result<Self::Ok, Self::Error> {
195        match self.union.find_named_schema("u64", self.config.names)? {
196            Some((index, Schema::Fixed(FixedSchema { size: 8, .. }))) => {
197                let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
198                bytes_written += self.write_array(v.to_le_bytes())?;
199                Ok(bytes_written)
200            }
201            _ => Err(self.error(
202                "u64",
203                r#"Expected Schema::Fixed(name: "u64", size: 8) in variants"#,
204            )),
205        }
206    }
207
208    fn serialize_u128(mut self, v: u128) -> Result<Self::Ok, Self::Error> {
209        match self.union.find_named_schema("u128", self.config.names)? {
210            Some((index, Schema::Fixed(FixedSchema { size: 16, .. }))) => {
211                let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
212                bytes_written += self.write_array(v.to_le_bytes())?;
213                Ok(bytes_written)
214            }
215            _ => Err(self.error(
216                "u128",
217                r#"Expected Schema::Fixed(name: "u128", size: 16) in variants"#,
218            )),
219        }
220    }
221
222    fn serialize_f32(mut self, v: f32) -> Result<Self::Ok, Self::Error> {
223        if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Float) {
224            let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
225            bytes_written += self.write_array(v.to_le_bytes())?;
226            Ok(bytes_written)
227        } else {
228            Err(self.error("f32", "Expected Schema::Float in variants"))
229        }
230    }
231
232    fn serialize_f64(mut self, v: f64) -> Result<Self::Ok, Self::Error> {
233        if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Double) {
234            let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
235            bytes_written += self.write_array(v.to_le_bytes())?;
236            Ok(bytes_written)
237        } else {
238            Err(self.error("f64", "Expected Schema::Double in variants"))
239        }
240    }
241
242    fn serialize_char(mut self, v: char) -> Result<Self::Ok, Self::Error> {
243        if let Some(index) = self.union.index_of_schema_kind(SchemaKind::String) {
244            let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
245            bytes_written += self.write_bytes_with_len(v.to_string().as_bytes())?;
246            Ok(bytes_written)
247        } else {
248            Err(self.error("char", "Expected Schema::String in variants"))
249        }
250    }
251
252    fn serialize_str(mut self, v: &str) -> Result<Self::Ok, Self::Error> {
253        if let Some(index) = self.union.index_of_schema_kind(SchemaKind::String) {
254            let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
255            bytes_written += self.write_bytes_with_len(v.as_bytes())?;
256            Ok(bytes_written)
257        } else {
258            Err(self.error("str", "Expected Schema::String in variants"))
259        }
260    }
261
262    fn serialize_bytes(mut self, v: &[u8]) -> Result<Self::Ok, Self::Error> {
263        let (index, with_len) = match SER_BYTES_TYPE.get() {
264            BytesType::Bytes => {
265                if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Bytes) {
266                    (index, true)
267                } else {
268                    return Err(self.error("bytes", "Expected Schema::Bytes in variants"));
269                }
270            }
271            BytesType::Fixed => {
272                if let Some((index, _)) = self
273                    .union
274                    .find_fixed_of_size_n(v.len(), self.config.names)?
275                {
276                    (index, false)
277                } else {
278                    return Err(self.error(
279                        "bytes",
280                        format!("Expected Schema::Fixed(size: {}) in variants", v.len()),
281                    ));
282                }
283            }
284            BytesType::Unset => {
285                let bytes_index = self.union.index_of_schema_kind(SchemaKind::Bytes);
286                let fixed_index = self
287                    .union
288                    .find_fixed_of_size_n(v.len(), self.config.names)?;
289                // Find the first variant that matches the bytes or fixed
290                match (bytes_index, fixed_index) {
291                    (Some(bytes_index), Some((fixed_index, _))) => {
292                        (bytes_index.min(fixed_index), bytes_index < fixed_index)
293                    }
294                    (Some(bytes_index), None) => (bytes_index, true),
295                    (None, Some((fixed_index, _))) => (fixed_index, false),
296                    (None, None) => {
297                        return Err(self.error(
298                            "bytes",
299                            format!(
300                                "Expected Schema::Bytes | Schema::Fixed(size: {}) in variants",
301                                v.len()
302                            ),
303                        ));
304                    }
305                }
306            }
307        };
308        let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
309        if with_len {
310            bytes_written += self.write_bytes_with_len(v)?;
311        } else {
312            bytes_written += self.write_bytes(v)?;
313        }
314        Ok(bytes_written)
315    }
316
317    fn serialize_none(self) -> Result<Self::Ok, Self::Error> {
318        Err(self.error("none", "Nested unions are not supported"))
319    }
320
321    fn serialize_some<T>(self, _: &T) -> Result<Self::Ok, Self::Error>
322    where
323        T: ?Sized + Serialize,
324    {
325        Err(self.error("some", "Nested unions are not supported"))
326    }
327
328    fn serialize_unit(self) -> Result<Self::Ok, Self::Error> {
329        if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Null) {
330            zig_i32(index as i32, &mut *self.writer)
331        } else {
332            Err(self.error("unit", "Expected Schema::Null in variants"))
333        }
334    }
335
336    fn serialize_unit_struct(self, name: &'static str) -> Result<Self::Ok, Self::Error> {
337        match self.union.find_named_schema(name, self.config.names)? {
338            Some((index, Schema::Record(record))) if record.fields.is_empty() => {
339                zig_i32(index as i32, &mut *self.writer)
340            }
341            _ => Err(self.error(
342                "unit struct",
343                format!("Expected Schema::Record(name: {name}, fields: []) in variants"),
344            )),
345        }
346    }
347
348    fn serialize_unit_variant(
349        self,
350        _: &'static str,
351        _: u32,
352        _: &'static str,
353    ) -> Result<Self::Ok, Self::Error> {
354        Err(self.error("unit variant", "Nested unions are not supported"))
355    }
356
357    fn serialize_newtype_struct<T>(
358        self,
359        name: &'static str,
360        value: &T,
361    ) -> Result<Self::Ok, Self::Error>
362    where
363        T: ?Sized + Serialize,
364    {
365        match self.union.find_named_schema(name, self.config.names)? {
366            Some((index, Schema::Record(record))) if record.fields.len() == 1 => {
367                let mut bytes_written = zig_i32(index as i32, &mut *self.writer)?;
368                bytes_written += value.serialize(SchemaAwareSerializer::new(
369                    self.writer,
370                    &record.fields[0].schema,
371                    self.config,
372                )?)?;
373                Ok(bytes_written)
374            }
375            _ => Err(self.error(
376                "newtype struct",
377                format!("Expected Schema::Record(name: {name}, fields: [_]) in variants"),
378            )),
379        }
380    }
381
382    fn serialize_newtype_variant<T>(
383        self,
384        _: &'static str,
385        _: u32,
386        _: &'static str,
387        _: &T,
388    ) -> Result<Self::Ok, Self::Error>
389    where
390        T: ?Sized + Serialize,
391    {
392        Err(self.error("newtype variant", "Nested unions are not supported"))
393    }
394
395    fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq, Self::Error> {
396        if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Array)
397            && let Schema::Array(array) = &self.union.variants()[index]
398        {
399            let bytes_written = zig_i32(index as i32, &mut *self.writer)?;
400            BlockSerializer::array(self.writer, array, self.config, len, Some(bytes_written))
401        } else {
402            Err(self.error("array", "Expected Schema::Array in variants"))
403        }
404    }
405
406    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple, Self::Error> {
407        if len == 0 {
408            if let Some(index) = self.union.index_of_schema_kind(SchemaKind::Null) {
409                let bytes_written = zig_i32(index as i32, &mut *self.writer)?;
410                Ok(TupleSerializer::unit(Some(bytes_written)))
411            } else {
412                Err(self.error("tuple", "Expected Schema::Null in variants for 0-tuple"))
413            }
414        } else if len == 1 {
415            Ok(TupleSerializer::one_union(
416                self.writer,
417                self.union,
418                self.config,
419            ))
420        } else if let Some((index, record)) = self
421            .union
422            .find_record_with_n_fields(len, self.config.names)?
423        {
424            let bytes_written = zig_i32(index as i32, &mut *self.writer)?;
425            Ok(TupleSerializer::many(
426                self.writer,
427                record,
428                self.config,
429                Some(bytes_written),
430            ))
431        } else {
432            Err(self.error(
433                "tuple",
434                format!(
435                    "Expected Schema::Record(fields.len() == {len}) in variants for {len}-tuple"
436                ),
437            ))
438        }
439    }
440
441    fn serialize_tuple_struct(
442        self,
443        name: &'static str,
444        len: usize,
445    ) -> Result<Self::SerializeTupleStruct, Self::Error> {
446        match self.union.find_named_schema(name, self.config.names)? {
447            Some((index, Schema::Record(record))) if record.fields.len() == len => {
448                let bytes_written = zig_i32(index as i32, &mut *self.writer)?;
449                Ok(ManyTupleSerializer::new(
450                    self.writer,
451                    record,
452                    self.config,
453                    Some(bytes_written),
454                ))
455            }
456            _ => Err(self.error(
457                "tuple struct",
458                format!("Expected Schema::Record(name: {name}, fields.len() == {len}) in variants"),
459            )),
460        }
461    }
462
463    fn serialize_tuple_variant(
464        self,
465        _: &'static str,
466        _: u32,
467        _: &'static str,
468        _: usize,
469    ) -> Result<Self::SerializeTupleVariant, Self::Error> {
470        Err(self.error("tuple variant", "Nested unions are not supported"))
471    }
472
473    fn serialize_map(self, len: Option<usize>) -> Result<Self::SerializeMap, Self::Error> {
474        let map_index = self.union.index_of_schema_kind(SchemaKind::Map).map(|i| {
475            if let Schema::Map(map) = &self.union.variants()[i] {
476                (i, map)
477            } else {
478                unreachable!("SchemaKind is Map so Schema must also be a Map")
479            }
480        });
481        let record_index = if let Some(len) = len {
482            self.union
483                .find_record_with_n_fields(len, self.config.names)?
484        } else {
485            None
486        };
487        match (map_index, record_index) {
488            (Some((map_index, map)), Some((record_index, record))) => {
489                let bytes_written = zig_i32(map_index.min(record_index) as i32, &mut *self.writer)?;
490                if map_index < record_index {
491                    MapOrRecordSerializer::map(
492                        self.writer,
493                        map,
494                        self.config,
495                        len,
496                        Some(bytes_written),
497                    )
498                } else {
499                    Ok(MapOrRecordSerializer::record(
500                        self.writer,
501                        record,
502                        self.config,
503                        Some(bytes_written),
504                    ))
505                }
506            }
507            (Some((map_index, map)), None) => {
508                let bytes_written = zig_i32(map_index as i32, &mut *self.writer)?;
509                MapOrRecordSerializer::map(self.writer, map, self.config, len, Some(bytes_written))
510            }
511            (None, Some((record_index, record))) => {
512                let bytes_written = zig_i32(record_index as i32, &mut *self.writer)?;
513                Ok(MapOrRecordSerializer::record(
514                    self.writer,
515                    record,
516                    self.config,
517                    Some(bytes_written),
518                ))
519            }
520            (None, None) => Err(self.error(
521                "map",
522                "Expected Schema::Map or Schema::Record for structs with flattened fields in variants",
523            )),
524        }
525    }
526
527    fn serialize_struct(
528        self,
529        name: &'static str,
530        _len: usize,
531    ) -> Result<Self::SerializeStruct, Self::Error> {
532        if let Some((index, Schema::Record(record))) =
533            self.union.find_named_schema(name, self.config.names)?
534        {
535            let bytes_written = zig_i32(index as i32, &mut *self.writer)?;
536            Ok(RecordSerializer::new(
537                self.writer,
538                record,
539                self.config,
540                Some(bytes_written),
541            ))
542        } else {
543            Err(self.error(
544                "struct",
545                format!("Expected Schema::Record(name: {name}) in variants"),
546            ))
547        }
548    }
549
550    fn serialize_struct_variant(
551        self,
552        _: &'static str,
553        _: u32,
554        _: &'static str,
555        _: usize,
556    ) -> Result<Self::SerializeStructVariant, Self::Error> {
557        Err(self.error("struct variant", "Nested unions are not supported"))
558    }
559
560    fn is_human_readable(&self) -> bool {
561        self.config.human_readable
562    }
563}