1use crate::error::Details;
19use crate::schema::{
20 DecimalSchema, InnerDecimalSchema, Name, NamespaceRef, Schema, SchemaKind, UuidSchema,
21};
22use crate::types;
23use crate::{AvroResult, Error};
24use std::borrow::Borrow;
25use std::collections::{BTreeMap, HashMap};
26use std::fmt::{Debug, Formatter};
27use strum::IntoDiscriminant;
28
29#[derive(Clone)]
31pub struct UnionSchema {
32 pub(crate) schemas: Vec<Schema>,
34 variant_index: BTreeMap<SchemaKind, usize>,
38 named_index: Vec<usize>,
42}
43
44impl Debug for UnionSchema {
45 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
46 f.debug_struct("UnionSchema")
48 .field("schemas", &self.schemas)
49 .finish()
50 }
51}
52
53impl UnionSchema {
54 pub fn new(schemas: Vec<Schema>) -> AvroResult<Self> {
60 let mut builder = Self::builder();
61 for schema in schemas {
62 builder.variant(schema)?;
63 }
64 Ok(builder.build())
65 }
66
67 pub fn builder() -> UnionSchemaBuilder {
69 UnionSchemaBuilder::new()
70 }
71
72 pub fn variants(&self) -> &[Schema] {
74 &self.schemas
75 }
76
77 pub fn is_nullable(&self) -> bool {
79 self.variant_index.contains_key(&SchemaKind::Null)
80 }
81
82 pub fn find_schema_with_known_schemata<S: Borrow<Schema> + Debug>(
88 &self,
89 value: &types::Value,
90 known_schemata: Option<&HashMap<Name, S>>,
91 enclosing_namespace: NamespaceRef,
92 ) -> Option<(usize, &Schema)> {
93 let known_schemata_if_none = HashMap::new();
94 let known_schemata = known_schemata.unwrap_or(&known_schemata_if_none);
95 let ValueSchemaKind { unnamed, named } = Self::value_to_base_schemakind(value);
96 let unnamed = unnamed
98 .and_then(|kind| self.variant_index.get(&kind).copied())
99 .map(|index| (index, &self.schemas[index]))
100 .and_then(|(index, schema)| {
101 let kind = schema.discriminant();
102 if kind == SchemaKind::Map || kind == SchemaKind::Array {
104 let namespace = schema.namespace().or(enclosing_namespace);
105
106 value
108 .clone()
109 .resolve_internal(schema, known_schemata, namespace, &None)
110 .ok()
111 .map(|_| (index, schema))
112 } else {
113 Some((index, schema))
114 }
115 });
116 let named = named.and_then(|kind| {
117 self.named_index
120 .iter()
121 .copied()
122 .map(|i| (i, &self.schemas[i]))
123 .filter(|(_i, s)| {
124 let s_kind = schema_to_base_schemakind(s);
125 s_kind == kind || s_kind == SchemaKind::Ref
126 })
127 .find(|(_i, schema)| {
128 let namespace = schema.namespace().or(enclosing_namespace);
129
130 value
132 .clone()
133 .resolve_internal(schema, known_schemata, namespace, &None)
134 .is_ok()
135 })
136 });
137
138 match (unnamed, named) {
139 (Some((u_i, _)), Some((n_i, _))) if u_i < n_i => unnamed,
140 (Some(_), Some(_)) => named,
141 (Some(_), None) => unnamed,
142 (None, Some(_)) => named,
143 (None, None) => {
144 self.schemas.iter().enumerate().find(|(_i, schema)| {
146 let namespace = schema.namespace().or(enclosing_namespace);
147
148 value
150 .clone()
151 .resolve_internal(schema, known_schemata, namespace, &None)
152 .is_ok()
153 })
154 }
155 }
156 }
157
158 fn value_to_base_schemakind(value: &types::Value) -> ValueSchemaKind {
160 let schemakind = SchemaKind::from(value);
161 match schemakind {
162 SchemaKind::Decimal => ValueSchemaKind {
163 unnamed: Some(SchemaKind::Bytes),
164 named: Some(SchemaKind::Fixed),
165 },
166 SchemaKind::BigDecimal => ValueSchemaKind {
167 unnamed: Some(SchemaKind::Bytes),
168 named: None,
169 },
170 SchemaKind::Uuid => ValueSchemaKind {
171 unnamed: Some(SchemaKind::String),
172 named: Some(SchemaKind::Fixed),
173 },
174 SchemaKind::Date | SchemaKind::TimeMillis => ValueSchemaKind {
175 unnamed: Some(SchemaKind::Int),
176 named: None,
177 },
178 SchemaKind::TimeMicros
179 | SchemaKind::TimestampMillis
180 | SchemaKind::TimestampMicros
181 | SchemaKind::TimestampNanos
182 | SchemaKind::LocalTimestampMillis
183 | SchemaKind::LocalTimestampMicros
184 | SchemaKind::LocalTimestampNanos => ValueSchemaKind {
185 unnamed: Some(SchemaKind::Long),
186 named: None,
187 },
188 SchemaKind::Duration => ValueSchemaKind {
189 unnamed: None,
190 named: Some(SchemaKind::Fixed),
191 },
192 SchemaKind::Record | SchemaKind::Enum | SchemaKind::Fixed => ValueSchemaKind {
193 unnamed: None,
194 named: Some(schemakind),
195 },
196 SchemaKind::Map => ValueSchemaKind {
199 unnamed: Some(SchemaKind::Map),
200 named: Some(SchemaKind::Record),
201 },
202 _ => ValueSchemaKind {
203 unnamed: Some(schemakind),
204 named: None,
205 },
206 }
207 }
208}
209
210struct ValueSchemaKind {
212 unnamed: Option<SchemaKind>,
213 named: Option<SchemaKind>,
214}
215
216impl PartialEq for UnionSchema {
218 fn eq(&self, other: &UnionSchema) -> bool {
219 self.schemas.eq(&other.schemas)
220 }
221}
222
/// Incrementally assembles a [`UnionSchema`], either rejecting or skipping
/// duplicate variants as they are added.
#[derive(Default, Debug)]
pub struct UnionSchemaBuilder {
    /// Variants accepted so far, in insertion order.
    schemas: Vec<Schema>,
    /// Position in `schemas` of each named variant, keyed by its name.
    names: HashMap<Name, usize>,
    /// Position in `schemas` of each unnamed variant, keyed by its base kind
    /// (as computed by `schema_to_base_schemakind`).
    variant_index: BTreeMap<SchemaKind, usize>,
}
230
231impl UnionSchemaBuilder {
232 pub fn new() -> Self {
236 Self::default()
237 }
238
239 #[doc(hidden)]
240 pub fn variant_ignore_duplicates(&mut self, schema: Schema) -> Result<&mut Self, Error> {
248 if let Some(name) = schema.name() {
249 if let Some(current) = self.names.get(name).copied() {
250 if self.schemas[current] != schema {
251 return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into());
252 }
253 } else {
254 self.names.insert(name.clone(), self.schemas.len());
255 self.schemas.push(schema);
256 }
257 } else if let Schema::Map(_) = &schema {
258 if let Some(index) = self.variant_index.get(&SchemaKind::Map).copied() {
259 if self.schemas[index] != schema {
260 return Err(
261 Details::GetUnionDuplicateMap(self.schemas[index].clone(), schema).into(),
262 );
263 }
264 } else {
265 self.variant_index
266 .insert(SchemaKind::Map, self.schemas.len());
267 self.schemas.push(schema);
268 }
269 } else if let Schema::Array(_) = &schema {
270 if let Some(index) = self.variant_index.get(&SchemaKind::Array).copied() {
271 if self.schemas[index] != schema {
272 return Err(Details::GetUnionDuplicateArray(
273 self.schemas[index].clone(),
274 schema,
275 )
276 .into());
277 }
278 } else {
279 self.variant_index
280 .insert(SchemaKind::Array, self.schemas.len());
281 self.schemas.push(schema);
282 }
283 } else {
284 let discriminant = schema_to_base_schemakind(&schema);
285 if discriminant == SchemaKind::Union {
286 return Err(Details::GetNestedUnion.into());
287 }
288 if !self.variant_index.contains_key(&discriminant) {
289 self.variant_index.insert(discriminant, self.schemas.len());
290 self.schemas.push(schema);
291 }
292 }
293 Ok(self)
294 }
295
296 pub fn variant(&mut self, schema: Schema) -> Result<&mut Self, Error> {
302 if let Some(name) = schema.name() {
303 if self.names.contains_key(name) {
304 return Err(Details::GetUnionDuplicateNamedSchemas(name.to_string()).into());
305 } else {
306 self.names.insert(name.clone(), self.schemas.len());
307 self.schemas.push(schema);
308 }
309 } else {
310 let discriminant = schema_to_base_schemakind(&schema);
311 if discriminant == SchemaKind::Union {
312 return Err(Details::GetNestedUnion.into());
313 }
314 if self.variant_index.contains_key(&discriminant) {
315 return Err(Details::GetUnionDuplicate(discriminant).into());
316 } else {
317 self.variant_index.insert(discriminant, self.schemas.len());
318 self.schemas.push(schema);
319 }
320 }
321 Ok(self)
322 }
323
324 pub fn contains(&self, schema: &Schema) -> bool {
326 if let Some(name) = schema.name() {
327 if let Some(current) = self.names.get(name).copied() {
328 &self.schemas[current] == schema
329 } else {
330 false
331 }
332 } else {
333 let discriminant = schema_to_base_schemakind(schema);
334 if let Some(index) = self.variant_index.get(&discriminant).copied() {
335 &self.schemas[index] == schema
336 } else {
337 false
338 }
339 }
340 }
341
342 pub fn build(mut self) -> UnionSchema {
344 self.schemas.shrink_to_fit();
345 let mut named_index: Vec<_> = self.names.into_values().collect();
346 named_index.sort();
347 UnionSchema {
348 variant_index: self.variant_index,
349 named_index,
350 schemas: self.schemas,
351 }
352 }
353}
354
355fn schema_to_base_schemakind(schema: &Schema) -> SchemaKind {
357 let kind = schema.discriminant();
358 match kind {
359 SchemaKind::Date | SchemaKind::TimeMillis => SchemaKind::Int,
360 SchemaKind::TimeMicros
361 | SchemaKind::TimestampMillis
362 | SchemaKind::TimestampMicros
363 | SchemaKind::TimestampNanos
364 | SchemaKind::LocalTimestampMillis
365 | SchemaKind::LocalTimestampMicros
366 | SchemaKind::LocalTimestampNanos => SchemaKind::Long,
367 SchemaKind::Uuid => match schema {
368 Schema::Uuid(UuidSchema::Bytes) => SchemaKind::Bytes,
369 Schema::Uuid(UuidSchema::String) => SchemaKind::String,
370 Schema::Uuid(UuidSchema::Fixed(_)) => SchemaKind::Fixed,
371 _ => unreachable!(),
372 },
373 SchemaKind::Decimal => match schema {
374 Schema::Decimal(DecimalSchema {
375 inner: InnerDecimalSchema::Bytes,
376 ..
377 }) => SchemaKind::Bytes,
378 Schema::Decimal(DecimalSchema {
379 inner: InnerDecimalSchema::Fixed(_),
380 ..
381 }) => SchemaKind::Fixed,
382 _ => unreachable!(),
383 },
384 SchemaKind::Duration => SchemaKind::Fixed,
385 _ => kind,
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{Details, Error};
    use crate::schema::RecordSchema;
    use crate::types::Value;
    use apache_avro_test_helper::TestResult;

    // `UnionSchema::new` keeps the variants in the order they were given.
    #[test]
    fn avro_rs_402_new_union_schema() -> TestResult {
        let schema1 = Schema::Int;
        let schema2 = Schema::String;
        let union_schema = UnionSchema::new(vec![schema1.clone(), schema2.clone()])?;

        assert_eq!(union_schema.variants(), &[schema1, schema2]);

        Ok(())
    }

    // Two named variants that share a name are rejected with
    // `GetUnionDuplicateNamedSchemas`, even though the schemas are identical.
    #[test]
    fn avro_rs_402_new_union_schema_duplicate_names() -> TestResult {
        let res = UnionSchema::new(vec![
            Schema::Record(RecordSchema::builder().try_name("Same_name")?.build()),
            Schema::Record(RecordSchema::builder().try_name("Same_name")?.build()),
        ])
        .map_err(Error::into_details);

        match res {
            Err(Details::GetUnionDuplicateNamedSchemas(name)) => {
                assert_eq!(name, Name::new("Same_name")?.to_string());
            }
            err => panic!("Expected GetUnionDuplicateNamedSchemas error, got: {err:?}"),
        }

        Ok(())
    }

    // For unnamed primitives, `variant` rejects a repeat while
    // `variant_ignore_duplicates` silently skips it, so each type ends up in
    // the union exactly once.
    #[test]
    fn avro_rs_489_union_schema_builder_primitive_type() -> TestResult {
        let mut builder = UnionSchema::builder();
        builder.variant(Schema::Null)?;
        assert!(builder.variant(Schema::Null).is_err());
        builder.variant_ignore_duplicates(Schema::Null)?;
        builder.variant(Schema::Int)?;
        assert!(builder.variant(Schema::Int).is_err());
        builder.variant_ignore_duplicates(Schema::Int)?;
        builder.variant(Schema::Long)?;
        assert!(builder.variant(Schema::Long).is_err());
        builder.variant_ignore_duplicates(Schema::Long)?;

        let union = builder.build();
        assert_eq!(union.schemas, &[Schema::Null, Schema::Int, Schema::Long]);

        Ok(())
    }

    // Named variants are keyed by name. `variant` rejects any reuse of a name.
    // `variant_ignore_duplicates` only tolerates an identical schema, so a fixed
    // named "ABC", or an "ABC" enum with an extra symbol, is still an error.
    #[test]
    fn avro_rs_489_union_schema_builder_complex_types() -> TestResult {
        let enum_abc = Schema::parse_str(
            r#"{
                "type": "enum",
                "name": "ABC",
                "symbols": ["A", "B", "C"]
            }"#,
        )?;
        let enum_abc_with_extra_symbol = Schema::parse_str(
            r#"{
                "type": "enum",
                "name": "ABC",
                "symbols": ["A", "B", "C", "D"]
            }"#,
        )?;
        let enum_def = Schema::parse_str(
            r#"{
                "type": "enum",
                "name": "DEF",
                "symbols": ["D", "E", "F"]
            }"#,
        )?;
        let fixed_abc = Schema::parse_str(
            r#"{
                "type": "fixed",
                "name": "ABC",
                "size": 1
            }"#,
        )?;
        let fixed_foo = Schema::parse_str(
            r#"{
                "type": "fixed",
                "name": "Foo",
                "size": 1
            }"#,
        )?;

        let mut builder = UnionSchema::builder();
        builder.variant(enum_abc.clone())?;
        assert!(builder.variant(enum_abc.clone()).is_err());
        builder.variant_ignore_duplicates(enum_abc.clone())?;
        // Same name, different schema: rejected by both methods.
        assert!(builder.variant(fixed_abc.clone()).is_err());
        assert!(
            builder
                .variant_ignore_duplicates(fixed_abc.clone())
                .is_err()
        );
        // Same name, same type, different symbols: also rejected by both.
        assert!(builder.variant(enum_abc_with_extra_symbol.clone()).is_err());
        assert!(
            builder
                .variant_ignore_duplicates(enum_abc_with_extra_symbol.clone())
                .is_err()
        );
        builder.variant(enum_def.clone())?;
        assert!(builder.variant(enum_def.clone()).is_err());
        builder.variant_ignore_duplicates(enum_def.clone())?;
        builder.variant(fixed_foo.clone())?;
        assert!(builder.variant(fixed_foo.clone()).is_err());
        builder.variant_ignore_duplicates(fixed_foo.clone())?;

        let union = builder.build();
        assert_eq!(union.variants(), &[enum_abc, enum_def, fixed_foo]);

        Ok(())
    }

    // Logical types share a slot with their base type. After `Date`,
    // `variant(Int)` fails, while `variant_ignore_duplicates(Int)` skips it and
    // keeps `Date`. A plain fixed named "Uuid" conflicts with the uuid-fixed of
    // the same name, so both methods reject it.
    #[test]
    fn avro_rs_489_union_schema_builder_logical_types() -> TestResult {
        let fixed_uuid = Schema::parse_str(
            r#"{
                "type": "fixed",
                "name": "Uuid",
                "size": 16
            }"#,
        )?;
        let uuid = Schema::parse_str(
            r#"{
                "type": "fixed",
                "logicalType": "uuid",
                "name": "Uuid",
                "size": 16
            }"#,
        )?;

        let mut builder = UnionSchema::builder();

        builder.variant(Schema::Date)?;
        assert!(builder.variant(Schema::Date).is_err());
        builder.variant_ignore_duplicates(Schema::Date)?;
        assert!(builder.variant(Schema::Int).is_err());
        builder.variant_ignore_duplicates(Schema::Int)?;
        builder.variant(uuid.clone())?;
        assert!(builder.variant(uuid.clone()).is_err());
        builder.variant_ignore_duplicates(uuid.clone())?;
        assert!(builder.variant(fixed_uuid.clone()).is_err());
        assert!(
            builder
                .variant_ignore_duplicates(fixed_uuid.clone())
                .is_err()
        );

        let union = builder.build();
        assert_eq!(union.schemas, &[Schema::Date, uuid]);

        Ok(())
    }

    // A map of strings must not be matched to a `map<int>` variant just because
    // both are maps. No variant accepts the value, so the lookup yields `None`.
    #[test]
    fn avro_rs_489_find_schema_with_known_schemata_wrong_map() -> TestResult {
        let union = UnionSchema::new(vec![Schema::map(Schema::Int).build(), Schema::Null])?;
        let value = Value::Map(
            [("key".to_string(), Value::String("value".to_string()))]
                .into_iter()
                .collect(),
        );

        assert!(
            union
                .find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None)
                .is_none()
        );

        Ok(())
    }

    // The union has no `int` variant to match directly, so the fallback scan
    // resolves the `int` value against the `long` variant.
    #[test]
    fn avro_rs_489_find_schema_with_known_schemata_type_promotion() -> TestResult {
        let union = UnionSchema::new(vec![Schema::Long, Schema::Null])?;
        let value = Value::Int(42);

        assert_eq!(
            union.find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None),
            Some((0, &Schema::Long))
        );

        Ok(())
    }

    // A 16-byte fixed value is matched to a named uuid variant backed by fixed.
    #[test]
    fn avro_rs_489_find_schema_with_known_schemata_uuid_vs_fixed() -> TestResult {
        let uuid = Schema::parse_str(
            r#"{
                "type": "fixed",
                "logicalType": "uuid",
                "name": "Uuid",
                "size": 16
            }"#,
        )?;
        let union = UnionSchema::new(vec![uuid.clone(), Schema::Null])?;
        let value = Value::Fixed(16, vec![0; 16]);

        assert_eq!(
            union.find_schema_with_known_schemata(&value, None::<&HashMap<Name, Schema>>, None),
            Some((0, &uuid))
        );

        Ok(())
    }
}