1use std::{
19 borrow::Borrow,
20 collections::{BTreeMap, HashMap},
21 fmt::{Debug, Formatter},
22};
23
24use strum::IntoDiscriminant;
25
26use crate::{
27 AvroResult, Error,
28 error::Details,
29 schema::{
30 DecimalSchema, FixedSchema, InnerDecimalSchema, Name, NamespaceRef, RecordSchema, Schema,
31 SchemaKind, UuidSchema,
32 },
33 types,
34};
35
#[derive(Clone)]
/// An Avro union schema: the ordered list of variant schemas plus lookup tables that let a
/// variant be located without scanning the whole list.
pub struct UnionSchema {
    /// The variants, in the order they were added to the builder.
    pub(crate) schemas: Vec<Schema>,
    // Maps the base `SchemaKind` of each unnamed variant (logical types collapsed onto their
    // underlying type) to that variant's index in `schemas`.
    variant_index: BTreeMap<SchemaKind, usize>,
    // Indices into `schemas` of the named variants; `UnionSchemaBuilder::build` sorts them
    // in ascending order.
    named_index: Vec<usize>,
}
50
51impl Debug for UnionSchema {
52 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
53 f.debug_struct("UnionSchema")
55 .field("schemas", &self.schemas)
56 .finish()
57 }
58}
59
60impl UnionSchema {
61 pub fn new(schemas: Vec<Schema>) -> AvroResult<Self> {
67 let mut builder = Self::builder();
68 for schema in schemas {
69 builder.variant(schema)?;
70 }
71 Ok(builder.build())
72 }
73
74 pub fn builder() -> UnionSchemaBuilder {
76 UnionSchemaBuilder::new()
77 }
78
79 pub fn variants(&self) -> &[Schema] {
81 &self.schemas
82 }
83
84 pub fn get_variant(&self, index: usize) -> Result<&Schema, Error> {
86 self.schemas.get(index).ok_or_else(|| {
87 Details::GetUnionVariant {
88 index: index as i64,
89 num_variants: self.schemas.len(),
90 }
91 .into()
92 })
93 }
94
95 pub(crate) fn index_of_schema_kind(&self, kind: SchemaKind) -> Option<usize> {
99 self.variant_index.get(&kind).copied()
100 }
101
102 pub(crate) fn find_named_schema<'s>(
108 &'s self,
109 name: &str,
110 names: &'s HashMap<Name, impl Borrow<Schema>>,
111 ) -> Result<Option<(usize, &'s Schema)>, Error> {
112 for index in self.named_index.iter().copied() {
113 let schema = &self.schemas[index];
114 if let Some(schema_name) = schema.name()
115 && schema_name.name() == name
116 {
117 let schema = if let Schema::Ref { name } = schema {
118 names
119 .get(name)
120 .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?
121 .borrow()
122 } else {
123 schema
124 };
125 return Ok(Some((index, schema)));
126 }
127 }
128 Ok(None)
129 }
130
131 pub(crate) fn find_fixed_of_size_n<'s>(
135 &'s self,
136 size: usize,
137 names: &'s HashMap<Name, impl Borrow<Schema>>,
138 ) -> Result<Option<(usize, &'s FixedSchema)>, Error> {
139 for index in self.named_index.iter().copied() {
140 let schema = &self.schemas[index];
141 let schema = if let Schema::Ref { name } = schema {
142 names
143 .get(name)
144 .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?
145 .borrow()
146 } else {
147 schema
148 };
149 match schema {
150 Schema::Fixed(fixed)
151 | Schema::Uuid(UuidSchema::Fixed(fixed))
152 | Schema::Decimal(DecimalSchema {
153 inner: InnerDecimalSchema::Fixed(fixed),
154 ..
155 })
156 | Schema::Duration(fixed)
157 if fixed.size == size =>
158 {
159 return Ok(Some((index, fixed)));
160 }
161 _ => {}
162 }
163 }
164 Ok(None)
165 }
166
167 pub(crate) fn find_record_with_n_fields<'s>(
171 &'s self,
172 n_fields: usize,
173 names: &'s HashMap<Name, impl Borrow<Schema>>,
174 ) -> Result<Option<(usize, &'s RecordSchema)>, Error> {
175 for index in self.named_index.iter().copied() {
176 let schema = &self.schemas[index];
177 let schema = if let Schema::Ref { name } = schema {
178 names
179 .get(name)
180 .ok_or_else(|| Details::SchemaResolutionError(name.clone()))?
181 .borrow()
182 } else {
183 schema
184 };
185 match schema {
186 Schema::Record(record) if record.fields.len() == n_fields => {
187 return Ok(Some((index, record)));
188 }
189 _ => {}
190 }
191 }
192 Ok(None)
193 }
194
195 pub fn is_nullable(&self) -> bool {
197 self.variant_index.contains_key(&SchemaKind::Null)
198 }
199
200 pub fn find_schema_with_known_schemata<S: Borrow<Schema> + Debug>(
206 &self,
207 value: &types::Value,
208 known_schemata: Option<&HashMap<Name, S>>,
209 enclosing_namespace: NamespaceRef,
210 ) -> Option<(usize, &Schema)> {
211 let known_schemata_if_none = HashMap::new();
212 let known_schemata = known_schemata.unwrap_or(&known_schemata_if_none);
213 let ValueSchemaKind { unnamed, named } = Self::value_to_base_schemakind(value);
214 let unnamed = unnamed
216 .and_then(|kind| self.variant_index.get(&kind).copied())
217 .map(|index| (index, &self.schemas[index]))
218 .and_then(|(index, schema)| {
219 let kind = schema.discriminant();
220 if kind == SchemaKind::Map || kind == SchemaKind::Array {
222 let namespace = schema.namespace().or(enclosing_namespace);
223
224 value
226 .clone()
227 .resolve_internal(schema, known_schemata, namespace, &None)
228 .ok()
229 .map(|_| (index, schema))
230 } else {
231 Some((index, schema))
232 }
233 });
234 let named = named.and_then(|kind| {
235 self.named_index
238 .iter()
239 .copied()
240 .map(|i| (i, &self.schemas[i]))
241 .filter(|(_i, s)| {
242 let s_kind = schema_to_base_schemakind(s);
243 s_kind == kind || s_kind == SchemaKind::Ref
244 })
245 .find(|(_i, schema)| {
246 let namespace = schema.namespace().or(enclosing_namespace);
247
248 value
250 .clone()
251 .resolve_internal(schema, known_schemata, namespace, &None)
252 .is_ok()
253 })
254 });
255
256 match (unnamed, named) {
257 (Some((u_i, _)), Some((n_i, _))) if u_i < n_i => unnamed,
258 (Some(_), Some(_)) => named,
259 (Some(_), None) => unnamed,
260 (None, Some(_)) => named,
261 (None, None) => {
262 self.schemas.iter().enumerate().find(|(_i, schema)| {
264 let namespace = schema.namespace().or(enclosing_namespace);
265
266 value
268 .clone()
269 .resolve_internal(schema, known_schemata, namespace, &None)
270 .is_ok()
271 })
272 }
273 }
274 }
275
276 fn value_to_base_schemakind(value: &types::Value) -> ValueSchemaKind {
278 let schemakind = SchemaKind::from(value);
279 match schemakind {
280 SchemaKind::Decimal => ValueSchemaKind {
281 unnamed: Some(SchemaKind::Bytes),
282 named: Some(SchemaKind::Fixed),
283 },
284 SchemaKind::BigDecimal => ValueSchemaKind {
285 unnamed: Some(SchemaKind::Bytes),
286 named: None,
287 },
288 SchemaKind::Uuid => ValueSchemaKind {
289 unnamed: Some(SchemaKind::String),
290 named: Some(SchemaKind::Fixed),
291 },
292 SchemaKind::Date | SchemaKind::TimeMillis => ValueSchemaKind {
293 unnamed: Some(SchemaKind::Int),
294 named: None,
295 },
296 SchemaKind::TimeMicros
297 | SchemaKind::TimestampMillis
298 | SchemaKind::TimestampMicros
299 | SchemaKind::TimestampNanos
300 | SchemaKind::LocalTimestampMillis
301 | SchemaKind::LocalTimestampMicros
302 | SchemaKind::LocalTimestampNanos => ValueSchemaKind {
303 unnamed: Some(SchemaKind::Long),
304 named: None,
305 },
306 SchemaKind::Duration => ValueSchemaKind {
307 unnamed: None,
308 named: Some(SchemaKind::Fixed),
309 },
310 SchemaKind::Record | SchemaKind::Enum | SchemaKind::Fixed => ValueSchemaKind {
311 unnamed: None,
312 named: Some(schemakind),
313 },
314 SchemaKind::Map => ValueSchemaKind {
317 unnamed: Some(SchemaKind::Map),
318 named: Some(SchemaKind::Record),
319 },
320 _ => ValueSchemaKind {
321 unnamed: Some(schemakind),
322 named: None,
323 },
324 }
325 }
326}
327
/// The base kinds a value could be stored as inside a union.
///
/// The kinds are split by whether the matching variant would be unnamed (looked up in
/// `UnionSchema::variant_index`) or named (searched via `UnionSchema::named_index`).
struct ValueSchemaKind {
    // Base kind to look up among the unnamed variants, if any.
    unnamed: Option<SchemaKind>,
    // Base kind to look for among the named variants, if any.
    named: Option<SchemaKind>,
}
333
334impl PartialEq for UnionSchema {
336 fn eq(&self, other: &UnionSchema) -> bool {
337 self.schemas.eq(&other.schemas)
338 }
339}
340
#[derive(Default, Debug)]
/// Incrementally assembles a [`UnionSchema`].
///
/// `variant` rejects any duplicate variant. The hidden `variant_ignore_duplicates` tolerates
/// some duplicates.
pub struct UnionSchemaBuilder {
    // The variants added so far, in insertion order.
    schemas: Vec<Schema>,
    // Name of each named variant -> its index in `schemas`.
    names: HashMap<Name, usize>,
    // Base kind of each unnamed variant -> its index in `schemas`.
    variant_index: BTreeMap<SchemaKind, usize>,
}
348
349impl UnionSchemaBuilder {
350 pub fn new() -> Self {
354 Self::default()
355 }
356
357 #[doc(hidden)]
358 pub fn variant_ignore_duplicates(&mut self, schema: Schema) -> Result<&mut Self, Error> {
366 if let Some(name) = schema.name() {
367 if let Some(current) = self.names.get(name).copied() {
368 if self.schemas[current] != schema {
369 return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into());
370 }
371 } else {
372 self.names.insert(name.clone(), self.schemas.len());
373 self.schemas.push(schema);
374 }
375 } else if let Schema::Map(_) = &schema {
376 if let Some(index) = self.variant_index.get(&SchemaKind::Map).copied() {
377 if self.schemas[index] != schema {
378 return Err(
379 Details::GetUnionDuplicateMap(self.schemas[index].clone(), schema).into(),
380 );
381 }
382 } else {
383 self.variant_index
384 .insert(SchemaKind::Map, self.schemas.len());
385 self.schemas.push(schema);
386 }
387 } else if let Schema::Array(_) = &schema {
388 if let Some(index) = self.variant_index.get(&SchemaKind::Array).copied() {
389 if self.schemas[index] != schema {
390 return Err(Details::GetUnionDuplicateArray(
391 self.schemas[index].clone(),
392 schema,
393 )
394 .into());
395 }
396 } else {
397 self.variant_index
398 .insert(SchemaKind::Array, self.schemas.len());
399 self.schemas.push(schema);
400 }
401 } else {
402 let discriminant = schema_to_base_schemakind(&schema);
403 if discriminant == SchemaKind::Union {
404 return Err(Details::GetNestedUnion.into());
405 }
406 if !self.variant_index.contains_key(&discriminant) {
407 self.variant_index.insert(discriminant, self.schemas.len());
408 self.schemas.push(schema);
409 }
410 }
411 Ok(self)
412 }
413
414 pub fn variant(&mut self, schema: Schema) -> Result<&mut Self, Error> {
420 if let Some(name) = schema.name() {
421 if self.names.contains_key(name) {
422 return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into());
423 } else {
424 self.names.insert(name.clone(), self.schemas.len());
425 self.schemas.push(schema);
426 }
427 } else {
428 let discriminant = schema_to_base_schemakind(&schema);
429 if discriminant == SchemaKind::Union {
430 return Err(Details::GetNestedUnion.into());
431 }
432 if self.variant_index.contains_key(&discriminant) {
433 return Err(Details::GetUnionDuplicate(discriminant).into());
434 } else {
435 self.variant_index.insert(discriminant, self.schemas.len());
436 self.schemas.push(schema);
437 }
438 }
439 Ok(self)
440 }
441
442 pub fn contains(&self, schema: &Schema) -> bool {
444 if let Some(name) = schema.name() {
445 if let Some(current) = self.names.get(name).copied() {
446 &self.schemas[current] == schema
447 } else {
448 false
449 }
450 } else {
451 let discriminant = schema_to_base_schemakind(schema);
452 if let Some(index) = self.variant_index.get(&discriminant).copied() {
453 &self.schemas[index] == schema
454 } else {
455 false
456 }
457 }
458 }
459
460 pub fn build(mut self) -> UnionSchema {
462 self.schemas.shrink_to_fit();
463 let mut named_index: Vec<_> = self.names.into_values().collect();
464 named_index.sort();
465 UnionSchema {
466 variant_index: self.variant_index,
467 named_index,
468 schemas: self.schemas,
469 }
470 }
471}
472
473fn schema_to_base_schemakind(schema: &Schema) -> SchemaKind {
475 let kind = schema.discriminant();
476 match kind {
477 SchemaKind::Date | SchemaKind::TimeMillis => SchemaKind::Int,
478 SchemaKind::TimeMicros
479 | SchemaKind::TimestampMillis
480 | SchemaKind::TimestampMicros
481 | SchemaKind::TimestampNanos
482 | SchemaKind::LocalTimestampMillis
483 | SchemaKind::LocalTimestampMicros
484 | SchemaKind::LocalTimestampNanos => SchemaKind::Long,
485 SchemaKind::Uuid => match schema {
486 Schema::Uuid(UuidSchema::Bytes) => SchemaKind::Bytes,
487 Schema::Uuid(UuidSchema::String) => SchemaKind::String,
488 Schema::Uuid(UuidSchema::Fixed(_)) => SchemaKind::Fixed,
489 _ => unreachable!(),
490 },
491 SchemaKind::Decimal => match schema {
492 Schema::Decimal(DecimalSchema {
493 inner: InnerDecimalSchema::Bytes,
494 ..
495 }) => SchemaKind::Bytes,
496 Schema::Decimal(DecimalSchema {
497 inner: InnerDecimalSchema::Fixed(_),
498 ..
499 }) => SchemaKind::Fixed,
500 _ => unreachable!(),
501 },
502 SchemaKind::Duration => SchemaKind::Fixed,
503 _ => kind,
504 }
505}
506
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{Details, Error};
    use crate::schema::RecordSchema;
    use crate::types::Value;
    use apache_avro_test_helper::TestResult;

    #[test]
    fn avro_rs_402_new_union_schema() -> TestResult {
        let schema1 = Schema::Int;
        let schema2 = Schema::String;
        let union_schema = UnionSchema::new(vec![schema1.clone(), schema2.clone()])?;

        assert_eq!(union_schema.variants(), &[schema1, schema2]);

        Ok(())
    }

    #[test]
    fn avro_rs_402_new_union_schema_duplicate_names() -> TestResult {
        let res = UnionSchema::new(vec![
            Schema::Record(RecordSchema::builder().try_name("Same_name")?.build()),
            Schema::Record(RecordSchema::builder().try_name("Same_name")?.build()),
        ])
        .map_err(Error::into_details);

        match res {
            Err(Details::GetUnionDuplicateNamedSchemas(name)) => {
                assert_eq!(name, Name::new("Same_name")?.to_string());
            }
            err => panic!("Expected GetUnionDuplicateNamedSchemas error, got: {err:?}"),
        }

        Ok(())
    }

    #[test]
    fn avro_rs_489_union_schema_builder_primitive_type() -> TestResult {
        let mut builder = UnionSchema::builder();
        builder.variant(Schema::Null)?;
        assert!(builder.variant(Schema::Null).is_err());
        builder.variant_ignore_duplicates(Schema::Null)?;
        builder.variant(Schema::Int)?;
        assert!(builder.variant(Schema::Int).is_err());
        builder.variant_ignore_duplicates(Schema::Int)?;
        builder.variant(Schema::Long)?;
        assert!(builder.variant(Schema::Long).is_err());
        builder.variant_ignore_duplicates(Schema::Long)?;

        let union = builder.build();
        assert_eq!(union.schemas, &[Schema::Null, Schema::Int, Schema::Long]);

        Ok(())
    }

    #[test]
    fn avro_rs_489_union_schema_builder_complex_types() -> TestResult {
        let enum_abc = Schema::parse_str(
            r#"{
                "type": "enum",
                "name": "ABC",
                "symbols": ["A", "B", "C"]
            }"#,
        )?;
        let enum_abc_with_extra_symbol = Schema::parse_str(
            r#"{
                "type": "enum",
                "name": "ABC",
                "symbols": ["A", "B", "C", "D"]
            }"#,
        )?;
        let enum_def = Schema::parse_str(
            r#"{
                "type": "enum",
                "name": "DEF",
                "symbols": ["D", "E", "F"]
            }"#,
        )?;
        let fixed_abc = Schema::parse_str(
            r#"{
                "type": "fixed",
                "name": "ABC",
                "size": 1
            }"#,
        )?;
        let fixed_foo = Schema::parse_str(
            r#"{
                "type": "fixed",
                "name": "Foo",
                "size": 1
            }"#,
        )?;

        let mut builder = UnionSchema::builder();
        builder.variant(enum_abc.clone())?;
        assert!(builder.variant(enum_abc.clone()).is_err());
        builder.variant_ignore_duplicates(enum_abc.clone())?;
        assert!(builder.variant(fixed_abc.clone()).is_err());
        assert!(
            builder
                .variant_ignore_duplicates(fixed_abc.clone())
                .is_err()
        );
        assert!(builder.variant(enum_abc_with_extra_symbol.clone()).is_err());
        assert!(
            builder
                .variant_ignore_duplicates(enum_abc_with_extra_symbol.clone())
                .is_err()
        );
        builder.variant(enum_def.clone())?;
        assert!(builder.variant(enum_def.clone()).is_err());
        builder.variant_ignore_duplicates(enum_def.clone())?;
        builder.variant(fixed_foo.clone())?;
        assert!(builder.variant(fixed_foo.clone()).is_err());
        builder.variant_ignore_duplicates(fixed_foo.clone())?;

        let union = builder.build();
        assert_eq!(union.variants(), &[enum_abc, enum_def, fixed_foo]);

        Ok(())
    }

    #[test]
    fn avro_rs_489_union_schema_builder_logical_types() -> TestResult {
        let fixed_uuid = Schema::parse_str(
            r#"{
                "type": "fixed",
                "name": "Uuid",
                "size": 16
            }"#,
        )?;
        let uuid = Schema::parse_str(
            r#"{
                "type": "fixed",
                "logicalType": "uuid",
                "name": "Uuid",
                "size": 16
            }"#,
        )?;

        let mut builder = UnionSchema::builder();

        builder.variant(Schema::Date)?;
        assert!(builder.variant(Schema::Date).is_err());
        builder.variant_ignore_duplicates(Schema::Date)?;
        assert!(builder.variant(Schema::Int).is_err());
        builder.variant_ignore_duplicates(Schema::Int)?;
        builder.variant(uuid.clone())?;
        assert!(builder.variant(uuid.clone()).is_err());
        builder.variant_ignore_duplicates(uuid.clone())?;
        assert!(builder.variant(fixed_uuid.clone()).is_err());
        assert!(
            builder
                .variant_ignore_duplicates(fixed_uuid.clone())
                .is_err()
        );

        let union = builder.build();
        assert_eq!(union.schemas, &[Schema::Date, uuid]);

        Ok(())
    }

    #[test]
    fn avro_rs_489_find_schema_with_known_schemata_wrong_map() -> TestResult {
        let union = UnionSchema::new(vec![Schema::map(Schema::Int).build(), Schema::Null])?;
        let value = Value::Map(
            [("key".to_string(), Value::String("value".to_string()))]
                .into_iter()
                .collect(),
        );

        assert!(
            union
                .find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None)
                .is_none()
        );

        Ok(())
    }

    #[test]
    fn avro_rs_489_find_schema_with_known_schemata_type_promotion() -> TestResult {
        let union = UnionSchema::new(vec![Schema::Long, Schema::Null])?;
        let value = Value::Int(42);

        assert_eq!(
            union.find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None),
            Some((0, &Schema::Long))
        );

        Ok(())
    }

    #[test]
    fn avro_rs_489_find_schema_with_known_schemata_uuid_vs_fixed() -> TestResult {
        let uuid = Schema::parse_str(
            r#"{
                "type": "fixed",
                "logicalType": "uuid",
                "name": "Uuid",
                "size": 16
            }"#,
        )?;
        let union = UnionSchema::new(vec![uuid.clone(), Schema::Null])?;
        let value = Value::Fixed(16, vec![0; 16]);

        assert_eq!(
            union.find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None),
            Some((0, &uuid))
        );

        Ok(())
    }

    // `is_nullable` depends only on whether a `null` variant is present.
    #[test]
    fn union_schema_is_nullable() -> TestResult {
        assert!(UnionSchema::new(vec![Schema::Null, Schema::Int])?.is_nullable());
        assert!(!UnionSchema::new(vec![Schema::Int, Schema::String])?.is_nullable());

        Ok(())
    }

    // An out-of-range index reports both the requested index and the number of variants.
    #[test]
    fn union_schema_get_variant_out_of_range() -> TestResult {
        let union = UnionSchema::new(vec![Schema::Null, Schema::Int])?;
        assert_eq!(union.get_variant(1)?, &Schema::Int);

        match union.get_variant(2).map_err(Error::into_details) {
            Err(Details::GetUnionVariant {
                index,
                num_variants,
            }) => {
                assert_eq!(index, 2);
                assert_eq!(num_variants, 2);
            }
            other => panic!("Expected GetUnionVariant error, got: {other:?}"),
        }

        Ok(())
    }

    // `contains` requires the stored variant to be equal, not merely to share a base kind.
    #[test]
    fn union_schema_builder_contains() -> TestResult {
        let mut builder = UnionSchema::builder();
        builder.variant(Schema::Date)?;

        assert!(builder.contains(&Schema::Date));
        assert!(!builder.contains(&Schema::Int));
        assert!(!builder.contains(&Schema::Long));

        Ok(())
    }

    // Named variants are found by name and reported with their position in the union.
    #[test]
    fn union_schema_find_named_schema() -> TestResult {
        let record = Schema::Record(RecordSchema::builder().try_name("A")?.build());
        let union = UnionSchema::new(vec![Schema::Null, record.clone()])?;
        let names: HashMap<Name, Schema> = HashMap::new();

        assert_eq!(union.find_named_schema("A", &names)?, Some((1, &record)));
        assert_eq!(union.find_named_schema("B", &names)?, None);

        Ok(())
    }
}