apache_avro/schema/union.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use crate::error::Details;
19use crate::schema::{
20    DecimalSchema, InnerDecimalSchema, Name, NamespaceRef, Schema, SchemaKind, UuidSchema,
21};
22use crate::types;
23use crate::{AvroResult, Error};
24use std::borrow::Borrow;
25use std::collections::{BTreeMap, HashMap};
26use std::fmt::{Debug, Formatter};
27use strum::IntoDiscriminant;
28
29/// A description of a Union schema
30#[derive(Clone)]
31pub struct UnionSchema {
32    /// The schemas that make up this union
33    pub(crate) schemas: Vec<Schema>,
34    /// The indexes of unnamed types.
35    ///
36    /// Logical types have been reduced to their inner type.
37    variant_index: BTreeMap<SchemaKind, usize>,
38    /// The indexes of named types.
39    ///
40    /// The names themselves aren't saved as they aren't used.
41    named_index: Vec<usize>,
42}
43
44impl Debug for UnionSchema {
45    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46        // Doesn't include `variant_index` as it's a derivative of `schemas`
47        f.debug_struct("UnionSchema")
48            .field("schemas", &self.schemas)
49            .finish()
50    }
51}
52
53impl UnionSchema {
54    /// Creates a new `UnionSchema` from a vector of schemas.
55    ///
56    /// # Errors
57    /// Will return an error if `schemas` has duplicate unnamed schemas or if `schemas`
58    /// contains a union.
59    pub fn new(schemas: Vec<Schema>) -> AvroResult<Self> {
60        let mut builder = Self::builder();
61        for schema in schemas {
62            builder.variant(schema)?;
63        }
64        Ok(builder.build())
65    }
66
67    /// Build a `UnionSchema` piece-by-piece.
68    pub fn builder() -> UnionSchemaBuilder {
69        UnionSchemaBuilder::new()
70    }
71
72    /// Returns a slice to all variants of this schema.
73    pub fn variants(&self) -> &[Schema] {
74        &self.schemas
75    }
76
77    /// Returns true if the any of the variants of this `UnionSchema` is `Null`.
78    pub fn is_nullable(&self) -> bool {
79        self.variant_index.contains_key(&SchemaKind::Null)
80    }
81
82    /// Optionally returns a reference to the schema matched by this value, as well as its position
83    /// within this union.
84    ///
85    /// Extra arguments:
86    /// - `known_schemata` - mapping between `Name` and `Schema` - if passed, additional external schemas would be used to resolve references.
87    pub fn find_schema_with_known_schemata<S: Borrow<Schema> + Debug>(
88        &self,
89        value: &types::Value,
90        known_schemata: Option<&HashMap<Name, S>>,
91        enclosing_namespace: NamespaceRef,
92    ) -> Option<(usize, &Schema)> {
93        let known_schemata_if_none = HashMap::new();
94        let known_schemata = known_schemata.unwrap_or(&known_schemata_if_none);
95        let ValueSchemaKind { unnamed, named } = Self::value_to_base_schemakind(value);
96        // Unnamed schema types can be looked up directly using the variant_index
97        let unnamed = unnamed
98            .and_then(|kind| self.variant_index.get(&kind).copied())
99            .map(|index| (index, &self.schemas[index]))
100            .and_then(|(index, schema)| {
101                let kind = schema.discriminant();
102                // Maps and arrays need to be checked if they actually match the value
103                if kind == SchemaKind::Map || kind == SchemaKind::Array {
104                    let namespace = schema.namespace().or(enclosing_namespace);
105
106                    // TODO: Do this without the clone
107                    value
108                        .clone()
109                        .resolve_internal(schema, known_schemata, namespace, &None)
110                        .ok()
111                        .map(|_| (index, schema))
112                } else {
113                    Some((index, schema))
114                }
115            });
116        let named = named.and_then(|kind| {
117            // Every named type needs to be checked against a value until one matches
118
119            self.named_index
120                .iter()
121                .copied()
122                .map(|i| (i, &self.schemas[i]))
123                .filter(|(_i, s)| {
124                    let s_kind = schema_to_base_schemakind(s);
125                    s_kind == kind || s_kind == SchemaKind::Ref
126                })
127                .find(|(_i, schema)| {
128                    let namespace = schema.namespace().or(enclosing_namespace);
129
130                    // TODO: Do this without the clone
131                    value
132                        .clone()
133                        .resolve_internal(schema, known_schemata, namespace, &None)
134                        .is_ok()
135                })
136        });
137
138        match (unnamed, named) {
139            (Some((u_i, _)), Some((n_i, _))) if u_i < n_i => unnamed,
140            (Some(_), Some(_)) => named,
141            (Some(_), None) => unnamed,
142            (None, Some(_)) => named,
143            (None, None) => {
144                // Slow path, check if value can be promoted to any of the types in the union
145                self.schemas.iter().enumerate().find(|(_i, schema)| {
146                    let namespace = schema.namespace().or(enclosing_namespace);
147
148                    // TODO: Do this without the clone
149                    value
150                        .clone()
151                        .resolve_internal(schema, known_schemata, namespace, &None)
152                        .is_ok()
153                })
154            }
155        }
156    }
157
158    /// Convert a value to a [`SchemaKind`] stripping logical types to their base type.
159    fn value_to_base_schemakind(value: &types::Value) -> ValueSchemaKind {
160        let schemakind = SchemaKind::from(value);
161        match schemakind {
162            SchemaKind::Decimal => ValueSchemaKind {
163                unnamed: Some(SchemaKind::Bytes),
164                named: Some(SchemaKind::Fixed),
165            },
166            SchemaKind::BigDecimal => ValueSchemaKind {
167                unnamed: Some(SchemaKind::Bytes),
168                named: None,
169            },
170            SchemaKind::Uuid => ValueSchemaKind {
171                unnamed: Some(SchemaKind::String),
172                named: Some(SchemaKind::Fixed),
173            },
174            SchemaKind::Date | SchemaKind::TimeMillis => ValueSchemaKind {
175                unnamed: Some(SchemaKind::Int),
176                named: None,
177            },
178            SchemaKind::TimeMicros
179            | SchemaKind::TimestampMillis
180            | SchemaKind::TimestampMicros
181            | SchemaKind::TimestampNanos
182            | SchemaKind::LocalTimestampMillis
183            | SchemaKind::LocalTimestampMicros
184            | SchemaKind::LocalTimestampNanos => ValueSchemaKind {
185                unnamed: Some(SchemaKind::Long),
186                named: None,
187            },
188            SchemaKind::Duration => ValueSchemaKind {
189                unnamed: None,
190                named: Some(SchemaKind::Fixed),
191            },
192            SchemaKind::Record | SchemaKind::Enum | SchemaKind::Fixed => ValueSchemaKind {
193                unnamed: None,
194                named: Some(schemakind),
195            },
196            // When a `serde_json::Value` is converted to a `types::Value` a object will always become a map
197            // so a `types::Value::Map` can also be a record.
198            SchemaKind::Map => ValueSchemaKind {
199                unnamed: Some(SchemaKind::Map),
200                named: Some(SchemaKind::Record),
201            },
202            _ => ValueSchemaKind {
203                unnamed: Some(schemakind),
204                named: None,
205            },
206        }
207    }
208}
209
210/// The schema kinds matching a specific value.
211struct ValueSchemaKind {
212    unnamed: Option<SchemaKind>,
213    named: Option<SchemaKind>,
214}
215
216// No need to compare variant_index, it is derivative of schemas.
217impl PartialEq for UnionSchema {
218    fn eq(&self, other: &UnionSchema) -> bool {
219        self.schemas.eq(&other.schemas)
220    }
221}
222
223/// A builder for [`UnionSchema`]
224#[derive(Default, Debug)]
225pub struct UnionSchemaBuilder {
226    schemas: Vec<Schema>,
227    names: HashMap<Name, usize>,
228    variant_index: BTreeMap<SchemaKind, usize>,
229}
230
231impl UnionSchemaBuilder {
232    /// Create a builder.
233    ///
234    /// See also [`UnionSchema::builder`].
235    pub fn new() -> Self {
236        Self::default()
237    }
238
239    #[doc(hidden)]
240    /// This is not a public API, it should only be used by `avro_derive`
241    ///
242    /// Add a variant to this union, if it already exists ignore it.
243    ///
244    /// # Errors
245    /// Will return a [`Details::GetUnionDuplicateMap`] or [`Details::GetUnionDuplicateArray`] if
246    /// duplicate maps or arrays are encountered with different subtypes.
247    pub fn variant_ignore_duplicates(&mut self, schema: Schema) -> Result<&mut Self, Error> {
248        if let Some(name) = schema.name() {
249            if let Some(current) = self.names.get(name).copied() {
250                if self.schemas[current] != schema {
251                    return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into());
252                }
253            } else {
254                self.names.insert(name.clone(), self.schemas.len());
255                self.schemas.push(schema);
256            }
257        } else if let Schema::Map(_) = &schema {
258            if let Some(index) = self.variant_index.get(&SchemaKind::Map).copied() {
259                if self.schemas[index] != schema {
260                    return Err(
261                        Details::GetUnionDuplicateMap(self.schemas[index].clone(), schema).into(),
262                    );
263                }
264            } else {
265                self.variant_index
266                    .insert(SchemaKind::Map, self.schemas.len());
267                self.schemas.push(schema);
268            }
269        } else if let Schema::Array(_) = &schema {
270            if let Some(index) = self.variant_index.get(&SchemaKind::Array).copied() {
271                if self.schemas[index] != schema {
272                    return Err(Details::GetUnionDuplicateArray(
273                        self.schemas[index].clone(),
274                        schema,
275                    )
276                    .into());
277                }
278            } else {
279                self.variant_index
280                    .insert(SchemaKind::Array, self.schemas.len());
281                self.schemas.push(schema);
282            }
283        } else {
284            let discriminant = schema_to_base_schemakind(&schema);
285            if discriminant == SchemaKind::Union {
286                return Err(Details::GetNestedUnion.into());
287            }
288            if !self.variant_index.contains_key(&discriminant) {
289                self.variant_index.insert(discriminant, self.schemas.len());
290                self.schemas.push(schema);
291            }
292        }
293        Ok(self)
294    }
295
296    /// Add a variant to this union.
297    ///
298    /// # Errors
299    /// Will return a [`Details::GetUnionDuplicateNamedSchemas`] or [`Details::GetUnionDuplicate`] if
300    /// duplicate names or schema kinds are found.
301    pub fn variant(&mut self, schema: Schema) -> Result<&mut Self, Error> {
302        if let Some(name) = schema.name() {
303            if self.names.contains_key(name) {
304                return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into());
305            } else {
306                self.names.insert(name.clone(), self.schemas.len());
307                self.schemas.push(schema);
308            }
309        } else {
310            let discriminant = schema_to_base_schemakind(&schema);
311            if discriminant == SchemaKind::Union {
312                return Err(Details::GetNestedUnion.into());
313            }
314            if self.variant_index.contains_key(&discriminant) {
315                return Err(Details::GetUnionDuplicate(discriminant).into());
316            } else {
317                self.variant_index.insert(discriminant, self.schemas.len());
318                self.schemas.push(schema);
319            }
320        }
321        Ok(self)
322    }
323
324    /// Check if a schema already exists in this union.
325    pub fn contains(&self, schema: &Schema) -> bool {
326        if let Some(name) = schema.name() {
327            if let Some(current) = self.names.get(name).copied() {
328                &self.schemas[current] == schema
329            } else {
330                false
331            }
332        } else {
333            let discriminant = schema_to_base_schemakind(schema);
334            if let Some(index) = self.variant_index.get(&discriminant).copied() {
335                &self.schemas[index] == schema
336            } else {
337                false
338            }
339        }
340    }
341
342    /// Create the `UnionSchema`.
343    pub fn build(mut self) -> UnionSchema {
344        self.schemas.shrink_to_fit();
345        let mut named_index: Vec<_> = self.names.into_values().collect();
346        named_index.sort();
347        UnionSchema {
348            variant_index: self.variant_index,
349            named_index,
350            schemas: self.schemas,
351        }
352    }
353}
354
355/// Get the [`SchemaKind`] of a [`Schema`] converting logical types to their base type.
356fn schema_to_base_schemakind(schema: &Schema) -> SchemaKind {
357    let kind = schema.discriminant();
358    match kind {
359        SchemaKind::Date | SchemaKind::TimeMillis => SchemaKind::Int,
360        SchemaKind::TimeMicros
361        | SchemaKind::TimestampMillis
362        | SchemaKind::TimestampMicros
363        | SchemaKind::TimestampNanos
364        | SchemaKind::LocalTimestampMillis
365        | SchemaKind::LocalTimestampMicros
366        | SchemaKind::LocalTimestampNanos => SchemaKind::Long,
367        SchemaKind::Uuid => match schema {
368            Schema::Uuid(UuidSchema::Bytes) => SchemaKind::Bytes,
369            Schema::Uuid(UuidSchema::String) => SchemaKind::String,
370            Schema::Uuid(UuidSchema::Fixed(_)) => SchemaKind::Fixed,
371            _ => unreachable!(),
372        },
373        SchemaKind::Decimal => match schema {
374            Schema::Decimal(DecimalSchema {
375                inner: InnerDecimalSchema::Bytes,
376                ..
377            }) => SchemaKind::Bytes,
378            Schema::Decimal(DecimalSchema {
379                inner: InnerDecimalSchema::Fixed(_),
380                ..
381            }) => SchemaKind::Fixed,
382            _ => unreachable!(),
383        },
384        SchemaKind::Duration => SchemaKind::Fixed,
385        _ => kind,
386    }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{Details, Error};
    use crate::schema::RecordSchema;
    use crate::types::Value;
    use apache_avro_test_helper::TestResult;

    #[test]
    fn avro_rs_402_new_union_schema() -> TestResult {
        let schema1 = Schema::Int;
        let schema2 = Schema::String;
        let union_schema = UnionSchema::new(vec![schema1.clone(), schema2.clone()])?;

        assert_eq!(union_schema.variants(), &[schema1, schema2]);

        Ok(())
    }

    #[test]
    fn avro_rs_402_new_union_schema_duplicate_names() -> TestResult {
        let res = UnionSchema::new(vec![
            Schema::Record(RecordSchema::builder().try_name("Same_name")?.build()),
            Schema::Record(RecordSchema::builder().try_name("Same_name")?.build()),
        ])
        .map_err(Error::into_details);

        match res {
            Err(Details::GetUnionDuplicateNamedSchemas(name)) => {
                assert_eq!(name, Name::new("Same_name")?.to_string());
            }
            err => panic!("Expected GetUnionDuplicateNamedSchemas error, got: {err:?}"),
        }

        Ok(())
    }

    #[test]
    fn avro_rs_489_union_schema_builder_primitive_type() -> TestResult {
        let mut builder = UnionSchema::builder();
        builder.variant(Schema::Null)?;
        assert!(builder.variant(Schema::Null).is_err());
        builder.variant_ignore_duplicates(Schema::Null)?;
        builder.variant(Schema::Int)?;
        assert!(builder.variant(Schema::Int).is_err());
        builder.variant_ignore_duplicates(Schema::Int)?;
        builder.variant(Schema::Long)?;
        assert!(builder.variant(Schema::Long).is_err());
        builder.variant_ignore_duplicates(Schema::Long)?;

        let union = builder.build();
        assert_eq!(union.schemas, &[Schema::Null, Schema::Int, Schema::Long]);

        Ok(())
    }

    #[test]
    fn avro_rs_489_union_schema_builder_complex_types() -> TestResult {
        let enum_abc = Schema::parse_str(
            r#"{
            "type": "enum",
            "name": "ABC",
            "symbols": ["A", "B", "C"]
        }"#,
        )?;
        let enum_abc_with_extra_symbol = Schema::parse_str(
            r#"{
            "type": "enum",
            "name": "ABC",
            "symbols": ["A", "B", "C", "D"]
        }"#,
        )?;
        let enum_def = Schema::parse_str(
            r#"{
            "type": "enum",
            "name": "DEF",
            "symbols": ["D", "E", "F"]
        }"#,
        )?;
        let fixed_abc = Schema::parse_str(
            r#"{
            "type": "fixed",
            "name": "ABC",
            "size": 1
        }"#,
        )?;
        let fixed_foo = Schema::parse_str(
            r#"{
            "type": "fixed",
            "name": "Foo",
            "size": 1
        }"#,
        )?;

        let mut builder = UnionSchema::builder();
        builder.variant(enum_abc.clone())?;
        assert!(builder.variant(enum_abc.clone()).is_err());
        builder.variant_ignore_duplicates(enum_abc.clone())?;
        // Name is the same but different schemas, so should always fail
        assert!(builder.variant(fixed_abc.clone()).is_err());
        assert!(
            builder
                .variant_ignore_duplicates(fixed_abc.clone())
                .is_err()
        );
        // Name and schema type are the same but symbols are different
        assert!(builder.variant(enum_abc_with_extra_symbol.clone()).is_err());
        assert!(
            builder
                .variant_ignore_duplicates(enum_abc_with_extra_symbol.clone())
                .is_err()
        );
        builder.variant(enum_def.clone())?;
        assert!(builder.variant(enum_def.clone()).is_err());
        builder.variant_ignore_duplicates(enum_def.clone())?;
        builder.variant(fixed_foo.clone())?;
        assert!(builder.variant(fixed_foo.clone()).is_err());
        builder.variant_ignore_duplicates(fixed_foo.clone())?;

        let union = builder.build();
        assert_eq!(union.variants(), &[enum_abc, enum_def, fixed_foo]);

        Ok(())
    }

    #[test]
    fn avro_rs_489_union_schema_builder_logical_types() -> TestResult {
        let fixed_uuid = Schema::parse_str(
            r#"{
            "type": "fixed",
            "name": "Uuid",
            "size": 16
        }"#,
        )?;
        let uuid = Schema::parse_str(
            r#"{
            "type": "fixed",
            "logicalType": "uuid",
            "name": "Uuid",
            "size": 16
        }"#,
        )?;

        let mut builder = UnionSchema::builder();

        builder.variant(Schema::Date)?;
        assert!(builder.variant(Schema::Date).is_err());
        builder.variant_ignore_duplicates(Schema::Date)?;
        assert!(builder.variant(Schema::Int).is_err());
        builder.variant_ignore_duplicates(Schema::Int)?;
        builder.variant(uuid.clone())?;
        assert!(builder.variant(uuid.clone()).is_err());
        builder.variant_ignore_duplicates(uuid.clone())?;
        assert!(builder.variant(fixed_uuid.clone()).is_err());
        assert!(
            builder
                .variant_ignore_duplicates(fixed_uuid.clone())
                .is_err()
        );

        let union = builder.build();
        assert_eq!(union.schemas, &[Schema::Date, uuid]);

        Ok(())
    }

    #[test]
    fn avro_rs_489_find_schema_with_known_schemata_wrong_map() -> TestResult {
        let union = UnionSchema::new(vec![Schema::map(Schema::Int).build(), Schema::Null])?;
        let value = Value::Map(
            [("key".to_string(), Value::String("value".to_string()))]
                .into_iter()
                .collect(),
        );

        assert!(
            union
                .find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None)
                .is_none()
        );

        Ok(())
    }

    #[test]
    fn avro_rs_489_find_schema_with_known_schemata_type_promotion() -> TestResult {
        let union = UnionSchema::new(vec![Schema::Long, Schema::Null])?;
        let value = Value::Int(42);

        assert_eq!(
            union.find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None),
            Some((0, &Schema::Long))
        );

        Ok(())
    }

    #[test]
    fn avro_rs_489_find_schema_with_known_schemata_uuid_vs_fixed() -> TestResult {
        let uuid = Schema::parse_str(
            r#"{
            "type": "fixed",
            "logicalType": "uuid",
            "name": "Uuid",
            "size": 16
        }"#,
        )?;
        let union = UnionSchema::new(vec![uuid.clone(), Schema::Null])?;
        let value = Value::Fixed(16, vec![0; 16]);

        assert_eq!(
            union.find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None),
            Some((0, &uuid))
        );

        Ok(())
    }

    #[test]
    fn union_schema_is_nullable() -> TestResult {
        assert!(UnionSchema::new(vec![Schema::Int, Schema::Null])?.is_nullable());
        assert!(!UnionSchema::new(vec![Schema::Int, Schema::String])?.is_nullable());

        Ok(())
    }

    #[test]
    fn union_schema_builder_contains() -> TestResult {
        let mut builder = UnionSchema::builder();
        builder.variant(Schema::Date)?;

        assert!(builder.contains(&Schema::Date));
        // Same base type (`int`) as `date`, but not the same schema
        assert!(!builder.contains(&Schema::Int));
        assert!(!builder.contains(&Schema::Long));

        Ok(())
    }

    #[test]
    fn union_schema_rejects_nested_union() -> TestResult {
        let inner = Schema::Union(UnionSchema::new(vec![Schema::Null, Schema::Int])?);
        let res = UnionSchema::new(vec![Schema::String, inner]).map_err(Error::into_details);

        match res {
            Err(Details::GetNestedUnion) => {}
            other => panic!("Expected GetNestedUnion error, got: {other:?}"),
        }

        Ok(())
    }
}