1use std::{
19 borrow::Borrow,
20 collections::{BTreeMap, HashMap},
21 fmt::{Debug, Formatter},
22};
23
24use strum::IntoDiscriminant;
25
26use crate::{
27 AvroResult, Error,
28 error::Details,
29 schema::{
30 DecimalSchema, FixedSchema, InnerDecimalSchema, Name, NamespaceRef, RecordSchema, Schema,
31 SchemaKind, UuidSchema,
32 },
33 types,
34};
35
/// A union schema: an ordered list of variant schemas plus two lookup indexes over it.
///
/// Instances are assembled by [`UnionSchemaBuilder`], which fills the indexes while
/// variants are added.
#[derive(Clone)]
pub struct UnionSchema {
    /// The variants, in the order in which they were accepted by the builder.
    pub(crate) schemas: Vec<Schema>,
    /// Position in `schemas` of each variant that has no name, keyed by its base
    /// [`SchemaKind`] (logical types are indexed under the kind they are based on).
    variant_index: BTreeMap<SchemaKind, usize>,
    /// Positions in `schemas` of the variants that have a name, in ascending order.
    named_index: Vec<usize>,
}
50
51impl Debug for UnionSchema {
52 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("UnionSchema")
55 .field("schemas", &self.schemas)
56 .finish()
57 }
58}
59
60impl UnionSchema {
61 pub fn new(schemas: Vec<Schema>) -> AvroResult<Self> {
67 let mut builder = Self::builder();
68 for schema in schemas {
69 builder.variant(schema)?;
70 }
71 Ok(builder.build())
72 }
73
74 pub fn builder() -> UnionSchemaBuilder {
76 UnionSchemaBuilder::new()
77 }
78
79 pub fn variants(&self) -> &[Schema] {
81 &self.schemas
82 }
83
84 pub fn get_variant(&self, index: usize) -> Result<&Schema, Error> {
86 self.schemas.get(index).ok_or_else(|| {
87 Details::GetUnionVariant {
88 index: index as i64,
89 num_variants: self.schemas.len(),
90 }
91 .into()
92 })
93 }
94
95 pub(crate) fn index_of_schema_kind(&self, kind: SchemaKind) -> Option<usize> {
99 self.variant_index.get(&kind).copied()
100 }
101
102 pub(crate) fn find_named_schema<'s>(
108 &'s self,
109 name: &str,
110 names: &'s HashMap<Name, impl Borrow<Schema>>,
111 ) -> Result<Option<(usize, &'s Schema)>, Error> {
112 for index in self.named_index.iter().copied() {
113 let schema = &self.schemas[index];
114 if let Some(schema_name) = schema.name()
115 && schema_name.name() == name
116 {
117 let schema = if let Schema::Ref { name } = schema {
118 names
119 .get(name)
120 .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?
121 .borrow()
122 } else {
123 schema
124 };
125 return Ok(Some((index, schema)));
126 }
127 }
128 Ok(None)
129 }
130
131 pub(crate) fn find_fully_qualified_named_schema<'s>(
135 &'s self,
136 full_name: &str,
137 names: &'s HashMap<Name, impl Borrow<Schema>>,
138 ) -> Result<Option<(usize, &'s Schema)>, Error> {
139 for index in self.named_index.iter().copied() {
140 let schema = &self.schemas[index];
141 if let Some(schema_name) = schema.name()
142 && schema_name.as_ref() == full_name
143 {
144 let schema = if let Schema::Ref { name } = schema {
145 names
146 .get(name)
147 .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?
148 .borrow()
149 } else {
150 schema
151 };
152 return Ok(Some((index, schema)));
153 }
154 }
155 Ok(None)
156 }
157
158 pub(crate) fn find_fixed_of_size_n<'s>(
162 &'s self,
163 size: usize,
164 names: &'s HashMap<Name, impl Borrow<Schema>>,
165 ) -> Result<Option<(usize, &'s FixedSchema)>, Error> {
166 for index in self.named_index.iter().copied() {
167 let schema = &self.schemas[index];
168 let schema = if let Schema::Ref { name } = schema {
169 names
170 .get(name)
171 .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?
172 .borrow()
173 } else {
174 schema
175 };
176 match schema {
177 Schema::Fixed(fixed)
178 | Schema::Uuid(UuidSchema::Fixed(fixed))
179 | Schema::Decimal(DecimalSchema {
180 inner: InnerDecimalSchema::Fixed(fixed),
181 ..
182 })
183 | Schema::Duration(fixed)
184 if fixed.size == size =>
185 {
186 return Ok(Some((index, fixed)));
187 }
188 _ => {}
189 }
190 }
191 Ok(None)
192 }
193
194 pub(crate) fn find_record_with_n_fields<'s>(
198 &'s self,
199 n_fields: usize,
200 names: &'s HashMap<Name, impl Borrow<Schema>>,
201 ) -> Result<Option<(usize, &'s RecordSchema)>, Error> {
202 for index in self.named_index.iter().copied() {
203 let schema = &self.schemas[index];
204 let schema = if let Schema::Ref { name } = schema {
205 names
206 .get(name)
207 .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?
208 .borrow()
209 } else {
210 schema
211 };
212 match schema {
213 Schema::Record(record) if record.fields.len() == n_fields => {
214 return Ok(Some((index, record)));
215 }
216 _ => {}
217 }
218 }
219 Ok(None)
220 }
221
222 pub fn is_nullable(&self) -> bool {
224 self.variant_index.contains_key(&SchemaKind::Null)
225 }
226
227 pub fn find_schema_with_known_schemata<S: Borrow<Schema> + Debug>(
233 &self,
234 value: &types::Value,
235 known_schemata: Option<&HashMap<Name, S>>,
236 enclosing_namespace: NamespaceRef,
237 ) -> Option<(usize, &Schema)> {
238 let known_schemata_if_none = HashMap::new();
239 let known_schemata = known_schemata.unwrap_or(&known_schemata_if_none);
240 let ValueSchemaKind { unnamed, named } = Self::value_to_base_schemakind(value);
241 let unnamed = unnamed
243 .and_then(|kind| self.variant_index.get(&kind).copied())
244 .map(|index| (index, &self.schemas[index]))
245 .and_then(|(index, schema)| {
246 let kind = schema.discriminant();
247 if kind == SchemaKind::Map || kind == SchemaKind::Array {
249 let namespace = schema.namespace().or(enclosing_namespace);
250
251 value
253 .clone()
254 .resolve_internal(schema, known_schemata, namespace, &None)
255 .ok()
256 .map(|_| (index, schema))
257 } else {
258 Some((index, schema))
259 }
260 });
261 let named = named.and_then(|kind| {
262 self.named_index
265 .iter()
266 .copied()
267 .map(|i| (i, &self.schemas[i]))
268 .filter(|(_i, s)| {
269 let s_kind = schema_to_base_schemakind(s);
270 s_kind == kind || s_kind == SchemaKind::Ref
271 })
272 .find(|(_i, schema)| {
273 let namespace = schema.namespace().or(enclosing_namespace);
274
275 value
277 .clone()
278 .resolve_internal(schema, known_schemata, namespace, &None)
279 .is_ok()
280 })
281 });
282
283 match (unnamed, named) {
284 (Some((u_i, _)), Some((n_i, _))) if u_i < n_i => unnamed,
285 (Some(_), Some(_)) => named,
286 (Some(_), None) => unnamed,
287 (None, Some(_)) => named,
288 (None, None) => {
289 self.schemas.iter().enumerate().find(|(_i, schema)| {
291 let namespace = schema.namespace().or(enclosing_namespace);
292
293 value
295 .clone()
296 .resolve_internal(schema, known_schemata, namespace, &None)
297 .is_ok()
298 })
299 }
300 }
301 }
302
303 fn value_to_base_schemakind(value: &types::Value) -> ValueSchemaKind {
305 let schemakind = SchemaKind::from(value);
306 match schemakind {
307 SchemaKind::Decimal => ValueSchemaKind {
308 unnamed: Some(SchemaKind::Bytes),
309 named: Some(SchemaKind::Fixed),
310 },
311 SchemaKind::BigDecimal => ValueSchemaKind {
312 unnamed: Some(SchemaKind::Bytes),
313 named: None,
314 },
315 SchemaKind::Uuid => ValueSchemaKind {
316 unnamed: Some(SchemaKind::String),
317 named: Some(SchemaKind::Fixed),
318 },
319 SchemaKind::Date | SchemaKind::TimeMillis => ValueSchemaKind {
320 unnamed: Some(SchemaKind::Int),
321 named: None,
322 },
323 SchemaKind::TimeMicros
324 | SchemaKind::TimestampMillis
325 | SchemaKind::TimestampMicros
326 | SchemaKind::TimestampNanos
327 | SchemaKind::LocalTimestampMillis
328 | SchemaKind::LocalTimestampMicros
329 | SchemaKind::LocalTimestampNanos => ValueSchemaKind {
330 unnamed: Some(SchemaKind::Long),
331 named: None,
332 },
333 SchemaKind::Duration => ValueSchemaKind {
334 unnamed: None,
335 named: Some(SchemaKind::Fixed),
336 },
337 SchemaKind::Record | SchemaKind::Enum | SchemaKind::Fixed => ValueSchemaKind {
338 unnamed: None,
339 named: Some(schemakind),
340 },
341 SchemaKind::Map => ValueSchemaKind {
344 unnamed: Some(SchemaKind::Map),
345 named: Some(SchemaKind::Record),
346 },
347 _ => ValueSchemaKind {
348 unnamed: Some(schemakind),
349 named: None,
350 },
351 }
352 }
353}
354
/// The base schema kinds a value may be matched with inside a union.
struct ValueSchemaKind {
    /// Kind to look up among the unnamed variants, if any.
    unnamed: Option<SchemaKind>,
    /// Kind to look for among the named variants, if any.
    named: Option<SchemaKind>,
}
360
361impl PartialEq for UnionSchema {
363 fn eq(&self, other: &UnionSchema) -> bool {
364 self.schemas.eq(&other.schemas)
365 }
366}
367
/// Collects the variants of a [`UnionSchema`] one at a time, rejecting combinations
/// that are not allowed in a union (see `variant` and `variant_ignore_duplicates`).
#[derive(Default, Debug)]
pub struct UnionSchemaBuilder {
    /// Variants accepted so far, in insertion order.
    schemas: Vec<Schema>,
    /// Position in `schemas` of every accepted variant that has a name.
    names: HashMap<Name, usize>,
    /// Position in `schemas` of every accepted variant without a name, keyed by its
    /// base kind.
    variant_index: BTreeMap<SchemaKind, usize>,
}
375
376impl UnionSchemaBuilder {
377 pub fn new() -> Self {
381 Self::default()
382 }
383
384 #[doc(hidden)]
385 pub fn variant_ignore_duplicates(&mut self, schema: Schema) -> Result<&mut Self, Error> {
393 if let Some(name) = schema.name() {
394 if let Some(current) = self.names.get(name).copied() {
395 if self.schemas[current] != schema {
396 return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into());
397 }
398 } else {
399 self.names.insert(name.clone(), self.schemas.len());
400 self.schemas.push(schema);
401 }
402 } else if let Schema::Map(_) = &schema {
403 if let Some(index) = self.variant_index.get(&SchemaKind::Map).copied() {
404 if self.schemas[index] != schema {
405 return Err(
406 Details::GetUnionDuplicateMap(self.schemas[index].clone(), schema).into(),
407 );
408 }
409 } else {
410 self.variant_index
411 .insert(SchemaKind::Map, self.schemas.len());
412 self.schemas.push(schema);
413 }
414 } else if let Schema::Array(_) = &schema {
415 if let Some(index) = self.variant_index.get(&SchemaKind::Array).copied() {
416 if self.schemas[index] != schema {
417 return Err(Details::GetUnionDuplicateArray(
418 self.schemas[index].clone(),
419 schema,
420 )
421 .into());
422 }
423 } else {
424 self.variant_index
425 .insert(SchemaKind::Array, self.schemas.len());
426 self.schemas.push(schema);
427 }
428 } else {
429 let discriminant = schema_to_base_schemakind(&schema);
430 if discriminant == SchemaKind::Union {
431 return Err(Details::GetNestedUnion.into());
432 }
433 if !self.variant_index.contains_key(&discriminant) {
434 self.variant_index.insert(discriminant, self.schemas.len());
435 self.schemas.push(schema);
436 }
437 }
438 Ok(self)
439 }
440
441 pub fn variant(&mut self, schema: Schema) -> Result<&mut Self, Error> {
447 if let Some(name) = schema.name() {
448 if self.names.contains_key(name) {
449 return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into());
450 } else {
451 self.names.insert(name.clone(), self.schemas.len());
452 self.schemas.push(schema);
453 }
454 } else {
455 let discriminant = schema_to_base_schemakind(&schema);
456 if discriminant == SchemaKind::Union {
457 return Err(Details::GetNestedUnion.into());
458 }
459 if self.variant_index.contains_key(&discriminant) {
460 return Err(Details::GetUnionDuplicate(discriminant).into());
461 } else {
462 self.variant_index.insert(discriminant, self.schemas.len());
463 self.schemas.push(schema);
464 }
465 }
466 Ok(self)
467 }
468
469 pub fn contains(&self, schema: &Schema) -> bool {
471 if let Some(name) = schema.name() {
472 if let Some(current) = self.names.get(name).copied() {
473 &self.schemas[current] == schema
474 } else {
475 false
476 }
477 } else {
478 let discriminant = schema_to_base_schemakind(schema);
479 if let Some(index) = self.variant_index.get(&discriminant).copied() {
480 &self.schemas[index] == schema
481 } else {
482 false
483 }
484 }
485 }
486
487 pub fn build(mut self) -> UnionSchema {
489 self.schemas.shrink_to_fit();
490 let mut named_index: Vec<_> = self.names.into_values().collect();
491 named_index.sort();
492 UnionSchema {
493 variant_index: self.variant_index,
494 named_index,
495 schemas: self.schemas,
496 }
497 }
498}
499
500fn schema_to_base_schemakind(schema: &Schema) -> SchemaKind {
502 let kind = schema.discriminant();
503 match kind {
504 SchemaKind::Date | SchemaKind::TimeMillis => SchemaKind::Int,
505 SchemaKind::TimeMicros
506 | SchemaKind::TimestampMillis
507 | SchemaKind::TimestampMicros
508 | SchemaKind::TimestampNanos
509 | SchemaKind::LocalTimestampMillis
510 | SchemaKind::LocalTimestampMicros
511 | SchemaKind::LocalTimestampNanos => SchemaKind::Long,
512 SchemaKind::Uuid => match schema {
513 Schema::Uuid(UuidSchema::Bytes) => SchemaKind::Bytes,
514 Schema::Uuid(UuidSchema::String) => SchemaKind::String,
515 Schema::Uuid(UuidSchema::Fixed(_)) => SchemaKind::Fixed,
516 _ => unreachable!(),
517 },
518 SchemaKind::Decimal => match schema {
519 Schema::Decimal(DecimalSchema {
520 inner: InnerDecimalSchema::Bytes,
521 ..
522 }) => SchemaKind::Bytes,
523 Schema::Decimal(DecimalSchema {
524 inner: InnerDecimalSchema::Fixed(_),
525 ..
526 }) => SchemaKind::Fixed,
527 _ => unreachable!(),
528 },
529 SchemaKind::Duration => SchemaKind::Fixed,
530 _ => kind,
531 }
532}
533
534#[cfg(test)]
535mod tests {
536 use super::*;
537 use crate::error::{Details, Error};
538 use crate::schema::RecordSchema;
539 use crate::types::Value;
540 use apache_avro_test_helper::TestResult;
541
/// A union built from distinct primitive schemas exposes them in the given order.
#[test]
fn avro_rs_402_new_union_schema() -> TestResult {
    let expected = vec![Schema::Int, Schema::String];
    let union_schema = UnionSchema::new(expected.clone())?;

    assert_eq!(union_schema.variants(), expected.as_slice());

    Ok(())
}
552
/// Two records with the same name cannot be variants of the same union.
#[test]
fn avro_rs_402_new_union_schema_duplicate_names() -> TestResult {
    let first = Schema::Record(RecordSchema::builder().try_name("Same_name")?.build());
    let second = Schema::Record(RecordSchema::builder().try_name("Same_name")?.build());

    let outcome = UnionSchema::new(vec![first, second]).map_err(Error::into_details);

    match outcome {
        Err(Details::GetUnionDuplicateNamedSchemas(duplicate)) => {
            assert_eq!(duplicate, Name::new("Same_name")?.to_string());
        }
        err => panic!("Expected GetUnionDuplicateNamedSchemas error, got: {err:?}"),
    }

    Ok(())
}
570
/// For primitives, `variant` rejects a repeated kind while
/// `variant_ignore_duplicates` accepts the call without adding a second variant.
#[test]
fn avro_rs_489_union_schema_builder_primitive_type() -> TestResult {
    let mut builder = UnionSchema::builder();
    for primitive in [Schema::Null, Schema::Int, Schema::Long] {
        builder.variant(primitive.clone())?;
        assert!(builder.variant(primitive.clone()).is_err());
        builder.variant_ignore_duplicates(primitive)?;
    }

    let union = builder.build();
    assert_eq!(union.schemas, &[Schema::Null, Schema::Int, Schema::Long]);

    Ok(())
}
589
/// Exercises both builder methods with named schemas (enums and fixed): exact
/// duplicates, same-named schemas of a different type, and same-named schemas with a
/// different definition.
#[test]
fn avro_rs_489_union_schema_builder_complex_types() -> TestResult {
    let enum_abc = Schema::parse_str(
        r#"{
            "type": "enum",
            "name": "ABC",
            "symbols": ["A", "B", "C"]
        }"#,
    )?;
    // Same name as `enum_abc`, but one more symbol.
    let enum_abc_with_extra_symbol = Schema::parse_str(
        r#"{
            "type": "enum",
            "name": "ABC",
            "symbols": ["A", "B", "C", "D"]
        }"#,
    )?;
    let enum_def = Schema::parse_str(
        r#"{
            "type": "enum",
            "name": "DEF",
            "symbols": ["D", "E", "F"]
        }"#,
    )?;
    // Same name as `enum_abc`, but a different schema type.
    let fixed_abc = Schema::parse_str(
        r#"{
            "type": "fixed",
            "name": "ABC",
            "size": 1
        }"#,
    )?;
    let fixed_foo = Schema::parse_str(
        r#"{
            "type": "fixed",
            "name": "Foo",
            "size": 1
        }"#,
    )?;

    let mut builder = UnionSchema::builder();
    builder.variant(enum_abc.clone())?;
    assert!(builder.variant(enum_abc.clone()).is_err());
    builder.variant_ignore_duplicates(enum_abc.clone())?;
    assert!(builder.variant(fixed_abc.clone()).is_err());
    // The name is taken by a different schema, so ignoring duplicates does not help.
    assert!(
        builder
            .variant_ignore_duplicates(fixed_abc.clone())
            .is_err()
    );
    assert!(builder.variant(enum_abc_with_extra_symbol.clone()).is_err());
    // Same name but a different definition: also rejected.
    assert!(
        builder
            .variant_ignore_duplicates(enum_abc_with_extra_symbol.clone())
            .is_err()
    );
    builder.variant(enum_def.clone())?;
    assert!(builder.variant(enum_def.clone()).is_err());
    builder.variant_ignore_duplicates(enum_def.clone())?;
    builder.variant(fixed_foo.clone())?;
    assert!(builder.variant(fixed_foo.clone()).is_err());
    builder.variant_ignore_duplicates(fixed_foo.clone())?;

    // Only the three successfully added schemas end up in the union.
    let union = builder.build();
    assert_eq!(union.variants(), &[enum_abc, enum_def, fixed_foo]);

    Ok(())
}
658
/// Exercises both builder methods with logical types, which share a base kind or a
/// name with the plain schema they are built on.
#[test]
fn avro_rs_489_union_schema_builder_logical_types() -> TestResult {
    // Plain fixed schema with the same name as `uuid` below.
    let fixed_uuid = Schema::parse_str(
        r#"{
            "type": "fixed",
            "name": "Uuid",
            "size": 16
        }"#,
    )?;
    let uuid = Schema::parse_str(
        r#"{
            "type": "fixed",
            "logicalType": "uuid",
            "name": "Uuid",
            "size": 16
        }"#,
    )?;

    let mut builder = UnionSchema::builder();

    builder.variant(Schema::Date)?;
    assert!(builder.variant(Schema::Date).is_err());
    builder.variant_ignore_duplicates(Schema::Date)?;
    // `Int` clashes with the already present `Date`: `variant` rejects it, while
    // `variant_ignore_duplicates` accepts the call without adding a variant.
    assert!(builder.variant(Schema::Int).is_err());
    builder.variant_ignore_duplicates(Schema::Int)?;
    builder.variant(uuid.clone())?;
    assert!(builder.variant(uuid.clone()).is_err());
    builder.variant_ignore_duplicates(uuid.clone())?;
    // Same name as `uuid` but a different schema: rejected by both methods.
    assert!(builder.variant(fixed_uuid.clone()).is_err());
    assert!(
        builder
            .variant_ignore_duplicates(fixed_uuid.clone())
            .is_err()
    );

    let union = builder.build();
    assert_eq!(union.schemas, &[Schema::Date, uuid]);

    Ok(())
}
699
/// A map of strings must not be matched with a map-of-int variant.
#[test]
fn avro_rs_489_find_schema_with_known_schemata_wrong_map() -> TestResult {
    let union = UnionSchema::new(vec![Schema::map(Schema::Int).build(), Schema::Null])?;
    let entries = [("key".to_string(), Value::String("value".to_string()))];
    let value = Value::Map(entries.into_iter().collect());

    let found =
        union.find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None);
    assert!(found.is_none());

    Ok(())
}
717
/// An int value is matched with the `Long` variant when the union has no `Int`.
#[test]
fn avro_rs_489_find_schema_with_known_schemata_type_promotion() -> TestResult {
    let union = UnionSchema::new(vec![Schema::Long, Schema::Null])?;
    let value = Value::Int(42);

    let found =
        union.find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None);
    assert_eq!(found, Some((0, &Schema::Long)));

    Ok(())
}
730
/// A 16-byte fixed value is matched with a fixed-backed UUID variant.
#[test]
fn avro_rs_489_find_schema_with_known_schemata_uuid_vs_fixed() -> TestResult {
    let uuid = Schema::parse_str(
        r#"{
            "type": "fixed",
            "logicalType": "uuid",
            "name": "Uuid",
            "size": 16
        }"#,
    )?;
    let union = UnionSchema::new(vec![uuid.clone(), Schema::Null])?;
    let value = Value::Fixed(16, vec![0; 16]);

    // The UUID schema is the first variant, hence position 0.
    assert_eq!(
        union.find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None),
        Some((0, &uuid))
    );

    Ok(())
}
751}