apache_avro_derive/
lib.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#![cfg_attr(nightly, feature(proc_macro_diagnostic))]
19
20//! This crate is the implementation of the `AvroSchema` derive macro.
21//! Please use it via the [`apache-avro`](https://crates.io/crates/apache-avro) crate:
22//!
//! ```no_run
//! use apache_avro::AvroSchema;
//!
//! #[derive(AvroSchema)]
//! struct Example {
//!     field: i32,
//! }
//! ```
28//! Please see the documentation of the [`AvroSchema`] trait for instructions on how to use it.
29//!
30//! [`AvroSchema`]: https://docs.rs/apache-avro/latest/apache_avro/serde/trait.AvroSchema.html
31
32mod attributes;
33mod case;
34
35use proc_macro2::{Span, TokenStream};
36use quote::quote;
37use syn::{
38    Attribute, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, Generics, Ident, Meta, Type,
39    parse_macro_input, spanned::Spanned,
40};
41
42use crate::{
43    attributes::{FieldOptions, NamedTypeOptions, VariantOptions, With},
44    case::RenameRule,
45};
46
47#[proc_macro_derive(AvroSchema, attributes(avro, serde))]
48// Templated from Serde
49pub fn proc_macro_derive_avro_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
50    let input = parse_macro_input!(input as DeriveInput);
51    derive_avro_schema(input)
52        .unwrap_or_else(to_compile_errors)
53        .into()
54}
55
56fn derive_avro_schema(input: DeriveInput) -> Result<TokenStream, Vec<syn::Error>> {
57    // It would be nice to parse the attributes before the `match`, but we first need to validate that `input` is not a union.
58    // Otherwise a user could get errors related to the attributes and after fixing those get an error because the attributes were on a union.
59    let input_span = input.span();
60    match input.data {
61        syn::Data::Struct(data_struct) => {
62            let named_type_options = NamedTypeOptions::new(&input.ident, &input.attrs, input_span)?;
63            let (get_schema_impl, get_record_fields_impl) = if named_type_options.transparent {
64                get_transparent_struct_schema_def(data_struct.fields, input_span)?
65            } else {
66                let (schema_def, record_fields) =
67                    get_struct_schema_def(&named_type_options, data_struct, input.ident.span())?;
68                (
69                    handle_named_schemas(named_type_options.name, schema_def),
70                    record_fields,
71                )
72            };
73            Ok(create_trait_definition(
74                input.ident,
75                &input.generics,
76                get_schema_impl,
77                get_record_fields_impl,
78            ))
79        }
80        syn::Data::Enum(data_enum) => {
81            let named_type_options = NamedTypeOptions::new(&input.ident, &input.attrs, input_span)?;
82            if named_type_options.transparent {
83                return Err(vec![syn::Error::new(
84                    input_span,
85                    "AvroSchema: `#[serde(transparent)]` is only supported on structs",
86                )]);
87            }
88            let schema_def =
89                get_data_enum_schema_def(&named_type_options, data_enum, input.ident.span())?;
90            let inner = handle_named_schemas(named_type_options.name, schema_def);
91            Ok(create_trait_definition(
92                input.ident,
93                &input.generics,
94                inner,
95                quote! { None },
96            ))
97        }
98        syn::Data::Union(_) => Err(vec![syn::Error::new(
99            input_span,
100            "AvroSchema: derive only works for structs and simple enums",
101        )]),
102    }
103}
104
105/// Generate the trait definition with the correct generics
106fn create_trait_definition(
107    ident: Ident,
108    generics: &Generics,
109    get_schema_impl: TokenStream,
110    get_record_fields_impl: TokenStream,
111) -> TokenStream {
112    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
113    quote! {
114        #[automatically_derived]
115        impl #impl_generics ::apache_avro::AvroSchemaComponent for #ident #ty_generics #where_clause {
116            fn get_schema_in_ctxt(named_schemas: &mut ::std::collections::HashSet<::apache_avro::schema::Name>, enclosing_namespace: &::std::option::Option<::std::string::String>) -> ::apache_avro::schema::Schema {
117                #get_schema_impl
118            }
119
120            fn get_record_fields_in_ctxt(mut field_position: usize, named_schemas: &mut ::std::collections::HashSet<::apache_avro::schema::Name>, enclosing_namespace: &::std::option::Option<::std::string::String>) -> ::std::option::Option<::std::vec::Vec<::apache_avro::schema::RecordField>> {
121                #get_record_fields_impl
122            }
123        }
124    }
125}
126
127/// Generate the code to check `named_schemas` if this schema already exist
128fn handle_named_schemas(full_schema_name: String, schema_def: TokenStream) -> TokenStream {
129    quote! {
130        let name = apache_avro::schema::Name::new(#full_schema_name).expect(concat!("Unable to parse schema name ", #full_schema_name)).fully_qualified_name(enclosing_namespace);
131        if named_schemas.contains(&name) {
132            apache_avro::schema::Schema::Ref{name}
133        } else {
134            let enclosing_namespace = &name.namespace;
135            named_schemas.insert(name.clone());
136            #schema_def
137        }
138    }
139}
140
141/// Generate a schema definition for a struct.
142fn get_struct_schema_def(
143    container_attrs: &NamedTypeOptions,
144    data_struct: DataStruct,
145    ident_span: Span,
146) -> Result<(TokenStream, TokenStream), Vec<syn::Error>> {
147    let mut record_field_exprs = vec![];
148    match data_struct.fields {
149        Fields::Named(a) => {
150            for field in a.named {
151                let mut name = field
152                    .ident
153                    .as_ref()
154                    .expect("Field must have a name")
155                    .to_string();
156                if let Some(raw_name) = name.strip_prefix("r#") {
157                    name = raw_name.to_string();
158                }
159                let field_attrs = FieldOptions::new(&field.attrs, field.span())?;
160                let doc = preserve_optional(field_attrs.doc);
161                match (field_attrs.rename, container_attrs.rename_all) {
162                    (Some(rename), _) => {
163                        name = rename;
164                    }
165                    (None, rename_all) if rename_all != RenameRule::None => {
166                        name = rename_all.apply_to_field(&name);
167                    }
168                    _ => {}
169                }
170                if field_attrs.skip {
171                    continue;
172                } else if field_attrs.flatten {
173                    // Inline the fields of the child record at runtime, as we don't have access to
174                    // the schema here.
175                    let get_record_fields =
176                        get_field_get_record_fields_expr(&field, field_attrs.with)?;
177                    record_field_exprs.push(quote! {
178                        if let Some(flattened_fields) = #get_record_fields {
179                            field_position += flattened_fields.len();
180                            schema_fields.extend(flattened_fields);
181                        } else {
182                            panic!("{} does not have any fields to flatten to", stringify!(#field));
183                        }
184                    });
185
186                    // Don't add this field as it's been replaced by the child record fields
187                    continue;
188                }
189                let default_value = match field_attrs.default {
190                    Some(default_value) => {
191                        let _: serde_json::Value = serde_json::from_str(&default_value[..])
192                            .map_err(|e| {
193                                vec![syn::Error::new(
194                                    field.ident.span(),
195                                    format!("Invalid avro default json: \n{e}"),
196                                )]
197                            })?;
198                        quote! {
199                            Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str()))
200                        }
201                    }
202                    None => quote! { None },
203                };
204                let aliases = aliases(&field_attrs.alias);
205                let schema_expr = get_field_schema_expr(&field, field_attrs.with)?;
206                record_field_exprs.push(quote! {
207                    schema_fields.push(::apache_avro::schema::RecordField {
208                        name: #name.to_string(),
209                        doc: #doc,
210                        default: #default_value,
211                        aliases: #aliases,
212                        schema: #schema_expr,
213                        order: ::apache_avro::schema::RecordFieldOrder::Ascending,
214                        position: field_position,
215                        custom_attributes: Default::default(),
216                    });
217                    field_position += 1;
218                });
219            }
220        }
221        Fields::Unnamed(_) => {
222            return Err(vec![syn::Error::new(
223                ident_span,
224                "AvroSchema derive does not work for tuple structs",
225            )]);
226        }
227        Fields::Unit => {
228            return Err(vec![syn::Error::new(
229                ident_span,
230                "AvroSchema derive does not work for unit structs",
231            )]);
232        }
233    }
234
235    let record_doc = preserve_optional(container_attrs.doc.as_ref());
236    let record_aliases = aliases(&container_attrs.aliases);
237    let full_schema_name = &container_attrs.name;
238
239    // When flatten is involved, there will be more but we don't know how many. This optimises for
240    // the most common case where there is no flatten.
241    let minimum_fields = record_field_exprs.len();
242
243    let schema_def = quote! {
244        {
245            let mut schema_fields = Vec::with_capacity(#minimum_fields);
246            let mut field_position = 0;
247            #(#record_field_exprs)*
248            let schema_field_set: ::std::collections::HashSet<_> = schema_fields.iter().map(|rf| &rf.name).collect();
249            assert_eq!(schema_fields.len(), schema_field_set.len(), "Duplicate field names found: {schema_fields:?}");
250            let name = apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse struct name for schema {}", #full_schema_name)[..]);
251            let lookup: std::collections::BTreeMap<String, usize> = schema_fields
252                .iter()
253                .map(|field| (field.name.to_owned(), field.position))
254                .collect();
255            apache_avro::schema::Schema::Record(apache_avro::schema::RecordSchema {
256                name,
257                aliases: #record_aliases,
258                doc: #record_doc,
259                fields: schema_fields,
260                lookup,
261                attributes: Default::default(),
262            })
263        }
264    };
265    let record_fields = quote! {
266        let mut schema_fields = Vec::with_capacity(#minimum_fields);
267        #(#record_field_exprs)*
268        Some(schema_fields)
269    };
270
271    Ok((schema_def, record_fields))
272}
273
274/// Use the schema definition of the only field in the struct as the schema
275fn get_transparent_struct_schema_def(
276    fields: Fields,
277    input_span: Span,
278) -> Result<(TokenStream, TokenStream), Vec<syn::Error>> {
279    match fields {
280        Fields::Named(fields_named) => {
281            let mut found = None;
282            for field in fields_named.named {
283                let attrs = FieldOptions::new(&field.attrs, field.span())?;
284                if attrs.skip {
285                    continue;
286                }
287                if found.replace((field, attrs)).is_some() {
288                    return Err(vec![syn::Error::new(
289                        input_span,
290                        "AvroSchema: #[serde(transparent)] is only allowed on structs with one unskipped field",
291                    )]);
292                }
293            }
294
295            if let Some((field, attrs)) = found {
296                Ok((
297                    get_field_schema_expr(&field, attrs.with.clone())?,
298                    get_field_get_record_fields_expr(&field, attrs.with)?,
299                ))
300            } else {
301                Err(vec![syn::Error::new(
302                    input_span,
303                    "AvroSchema: #[serde(transparent)] is only allowed on structs with one unskipped field",
304                )])
305            }
306        }
307        Fields::Unnamed(_) => Err(vec![syn::Error::new(
308            input_span,
309            "AvroSchema: derive does not work for tuple structs",
310        )]),
311        Fields::Unit => Err(vec![syn::Error::new(
312            input_span,
313            "AvroSchema: derive does not work for unit structs",
314        )]),
315    }
316}
317
318fn get_field_schema_expr(field: &Field, with: With) -> Result<TokenStream, Vec<syn::Error>> {
319    match with {
320        With::Trait => Ok(type_to_schema_expr(&field.ty)?),
321        With::Serde(path) => {
322            Ok(quote! { #path::get_schema_in_ctxt(named_schemas, enclosing_namespace) })
323        }
324        With::Expr(Expr::Closure(closure)) => {
325            if closure.inputs.is_empty() {
326                Ok(quote! { (#closure)() })
327            } else {
328                Err(vec![syn::Error::new(
329                    field.span(),
330                    "Expected closure with 0 parameters",
331                )])
332            }
333        }
334        With::Expr(Expr::Path(path)) => Ok(quote! { #path(named_schemas, enclosing_namespace) }),
335        With::Expr(_expr) => Err(vec![syn::Error::new(
336            field.span(),
337            "Invalid expression, expected function or closure",
338        )]),
339    }
340}
341
342fn get_field_get_record_fields_expr(
343    field: &Field,
344    with: With,
345) -> Result<TokenStream, Vec<syn::Error>> {
346    match with {
347        With::Trait => Ok(type_to_get_record_fields_expr(&field.ty)?),
348        With::Serde(path) => Ok(
349            quote! { #path::get_record_fields_in_ctxt(field_position, named_schemas, enclosing_namespace) },
350        ),
351        With::Expr(Expr::Closure(closure)) => {
352            if closure.inputs.is_empty() {
353                Ok(quote! {
354                    ::apache_avro::serde::get_record_fields_in_ctxt(
355                        field_position,
356                        named_schemas,
357                        enclosing_namespace,
358                        |_, _| (#closure)(),
359                    )
360                })
361            } else {
362                Err(vec![syn::Error::new(
363                    field.span(),
364                    "Expected closure with 0 parameters",
365                )])
366            }
367        }
368        With::Expr(Expr::Path(path)) => Ok(quote! {
369            ::apache_avro::serde::get_record_fields_in_ctxt(field_position, named_schemas, enclosing_namespace, #path)
370        }),
371        With::Expr(_expr) => Err(vec![syn::Error::new(
372            field.span(),
373            "Invalid expression, expected function or closure",
374        )]),
375    }
376}
377
378/// Generate a schema definition for a enum.
379fn get_data_enum_schema_def(
380    container_attrs: &NamedTypeOptions,
381    data_enum: DataEnum,
382    ident_span: Span,
383) -> Result<TokenStream, Vec<syn::Error>> {
384    let doc = preserve_optional(container_attrs.doc.as_ref());
385    let enum_aliases = aliases(&container_attrs.aliases);
386    if data_enum.variants.iter().all(|v| Fields::Unit == v.fields) {
387        let default_value = default_enum_variant(&data_enum, ident_span)?;
388        let default = preserve_optional(default_value);
389        let mut symbols = Vec::new();
390        for variant in &data_enum.variants {
391            let field_attrs = VariantOptions::new(&variant.attrs, variant.span())?;
392            let name = match (field_attrs.rename, container_attrs.rename_all) {
393                (Some(rename), _) => rename,
394                (None, rename_all) if !matches!(rename_all, RenameRule::None) => {
395                    rename_all.apply_to_variant(&variant.ident.to_string())
396                }
397                _ => variant.ident.to_string(),
398            };
399            symbols.push(name);
400        }
401        let full_schema_name = &container_attrs.name;
402        Ok(quote! {
403            apache_avro::schema::Schema::Enum(apache_avro::schema::EnumSchema {
404                name: apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse enum name for schema {}", #full_schema_name)[..]),
405                aliases: #enum_aliases,
406                doc: #doc,
407                symbols: vec![#(#symbols.to_owned()),*],
408                default: #default,
409                attributes: Default::default(),
410            })
411        })
412    } else {
413        Err(vec![syn::Error::new(
414            ident_span,
415            "AvroSchema: derive does not work for enums with non unit structs",
416        )])
417    }
418}
419
420/// Takes in the Tokens of a type and returns the tokens of an expression with return type `Schema`
421fn type_to_schema_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
422    match ty {
423        Type::Array(_) | Type::Slice(_) | Type::Path(_) | Type::Reference(_) => Ok(
424            quote! {<#ty as apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
425        ),
426        Type::Ptr(_) => Err(vec![syn::Error::new_spanned(
427            ty,
428            "AvroSchema: derive does not support raw pointers",
429        )]),
430        Type::Tuple(_) => Err(vec![syn::Error::new_spanned(
431            ty,
432            "AvroSchema: derive does not support tuples",
433        )]),
434        _ => Err(vec![syn::Error::new_spanned(
435            ty,
436            format!(
437                "AvroSchema: Unexpected type encountered! Please open an issue if this kind of type should be supported: {ty:?}"
438            ),
439        )]),
440    }
441}
442
443fn type_to_get_record_fields_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
444    match ty {
445        Type::Array(_) | Type::Slice(_) | Type::Path(_) | Type::Reference(_) => Ok(
446            quote! {<#ty as apache_avro::AvroSchemaComponent>::get_record_fields_in_ctxt(field_position, named_schemas, enclosing_namespace)},
447        ),
448        Type::Ptr(_) => Err(vec![syn::Error::new_spanned(
449            ty,
450            "AvroSchema: derive does not support raw pointers",
451        )]),
452        Type::Tuple(_) => Err(vec![syn::Error::new_spanned(
453            ty,
454            "AvroSchema: derive does not support tuples",
455        )]),
456        _ => Err(vec![syn::Error::new_spanned(
457            ty,
458            format!(
459                "AvroSchema: Unexpected type encountered! Please open an issue if this kind of type should be supported: {ty:?}"
460            ),
461        )]),
462    }
463}
464
465fn default_enum_variant(
466    data_enum: &syn::DataEnum,
467    error_span: Span,
468) -> Result<Option<String>, Vec<syn::Error>> {
469    match data_enum
470        .variants
471        .iter()
472        .filter(|v| v.attrs.iter().any(is_default_attr))
473        .collect::<Vec<_>>()
474    {
475        variants if variants.is_empty() => Ok(None),
476        single if single.len() == 1 => Ok(Some(single[0].ident.to_string())),
477        multiple => Err(vec![syn::Error::new(
478            error_span,
479            format!(
480                "Multiple defaults defined: {:?}",
481                multiple
482                    .iter()
483                    .map(|v| v.ident.to_string())
484                    .collect::<Vec<String>>()
485            ),
486        )]),
487    }
488}
489
490fn is_default_attr(attr: &Attribute) -> bool {
491    matches!(attr, Attribute { meta: Meta::Path(path), .. } if path.get_ident().map(Ident::to_string).as_deref() == Some("default"))
492}
493
494/// Stolen from serde
495fn to_compile_errors(errors: Vec<syn::Error>) -> proc_macro2::TokenStream {
496    let compile_errors = errors.iter().map(syn::Error::to_compile_error);
497    quote!(#(#compile_errors)*)
498}
499
500fn preserve_optional(op: Option<impl quote::ToTokens>) -> TokenStream {
501    match op {
502        Some(tt) => quote! {Some(#tt.into())},
503        None => quote! {None},
504    }
505}
506
507fn aliases(op: &[impl quote::ToTokens]) -> TokenStream {
508    let items: Vec<TokenStream> = op
509        .iter()
510        .map(|tt| quote! {#tt.try_into().expect("Alias is invalid")})
511        .collect();
512    if items.is_empty() {
513        quote! {None}
514    } else {
515        quote! {Some(vec![#(#items),*])}
516    }
517}
518
519#[cfg(test)]
520mod tests {
521    use super::*;
522    use pretty_assertions::assert_eq;
523
524    #[test]
525    fn basic_case() {
526        let test_struct = quote! {
527            struct A {
528                a: i32,
529                b: String
530            }
531        };
532
533        match syn::parse2::<DeriveInput>(test_struct) {
534            Ok(input) => {
535                assert!(derive_avro_schema(input).is_ok())
536            }
537            Err(error) => panic!(
538                "Failed to parse as derive input when it should be able to. Error: {error:?}"
539            ),
540        };
541    }
542
543    #[test]
544    fn tuple_struct_unsupported() {
545        let test_tuple_struct = quote! {
546            struct B (i32, String);
547        };
548
549        match syn::parse2::<DeriveInput>(test_tuple_struct) {
550            Ok(input) => {
551                assert!(derive_avro_schema(input).is_err())
552            }
553            Err(error) => panic!(
554                "Failed to parse as derive input when it should be able to. Error: {error:?}"
555            ),
556        };
557    }
558
559    #[test]
560    fn unit_struct_unsupported() {
561        let test_tuple_struct = quote! {
562            struct AbsoluteUnit;
563        };
564
565        match syn::parse2::<DeriveInput>(test_tuple_struct) {
566            Ok(input) => {
567                assert!(derive_avro_schema(input).is_err())
568            }
569            Err(error) => panic!(
570                "Failed to parse as derive input when it should be able to. Error: {error:?}"
571            ),
572        };
573    }
574
575    #[test]
576    fn struct_with_optional() {
577        let struct_with_optional = quote! {
578            struct Test4 {
579                a : Option<i32>
580            }
581        };
582        match syn::parse2::<DeriveInput>(struct_with_optional) {
583            Ok(input) => {
584                assert!(derive_avro_schema(input).is_ok())
585            }
586            Err(error) => panic!(
587                "Failed to parse as derive input when it should be able to. Error: {error:?}"
588            ),
589        };
590    }
591
592    #[test]
593    fn test_basic_enum() {
594        let basic_enum = quote! {
595            enum Basic {
596                A,
597                B,
598                C,
599                D
600            }
601        };
602        match syn::parse2::<DeriveInput>(basic_enum) {
603            Ok(input) => {
604                assert!(derive_avro_schema(input).is_ok())
605            }
606            Err(error) => panic!(
607                "Failed to parse as derive input when it should be able to. Error: {error:?}"
608            ),
609        };
610    }
611
612    #[test]
613    fn avro_3687_basic_enum_with_default() {
614        let basic_enum = quote! {
615            enum Basic {
616                #[default]
617                A,
618                B,
619                C,
620                D
621            }
622        };
623        match syn::parse2::<DeriveInput>(basic_enum) {
624            Ok(input) => {
625                let derived = derive_avro_schema(input);
626                assert!(derived.is_ok());
627                assert_eq!(derived.unwrap().to_string(), quote! {
628                    #[automatically_derived]
629                    impl ::apache_avro::AvroSchemaComponent for Basic {
630                        fn get_schema_in_ctxt(
631                            named_schemas: &mut ::std::collections::HashSet<::apache_avro::schema::Name>,
632                            enclosing_namespace: &::std::option::Option<::std::string::String>
633                        ) -> ::apache_avro::schema::Schema {
634                            let name = apache_avro::schema::Name::new("Basic")
635                                .expect(concat!("Unable to parse schema name ", "Basic"))
636                                .fully_qualified_name(enclosing_namespace);
637                            if named_schemas.contains(&name) {
638                                apache_avro::schema::Schema::Ref { name }
639                            } else {
640                                let enclosing_namespace = &name.namespace;
641                                named_schemas.insert(name.clone());
642                                apache_avro::schema::Schema::Enum(apache_avro::schema::EnumSchema {
643                                    name: apache_avro::schema::Name::new("Basic").expect(
644                                        &format!("Unable to parse enum name for schema {}", "Basic")[..]
645                                    ),
646                                    aliases: None,
647                                    doc: None,
648                                    symbols: vec![
649                                        "A".to_owned(),
650                                        "B".to_owned(),
651                                        "C".to_owned(),
652                                        "D".to_owned()
653                                    ],
654                                    default: Some("A".into()),
655                                    attributes: Default::default(),
656                                })
657                            }
658                        }
659
660                        fn get_record_fields_in_ctxt(
661                            mut field_position: usize,
662                            named_schemas: &mut ::std::collections::HashSet<::apache_avro::schema::Name>,
663                            enclosing_namespace: &::std::option::Option<::std::string::String>
664                        ) -> ::std::option::Option <::std::vec::Vec<::apache_avro::schema::RecordField>> {
665                            None
666                        }
667                    }
668                }.to_string());
669            }
670            Err(error) => panic!(
671                "Failed to parse as derive input when it should be able to. Error: {error:?}"
672            ),
673        };
674    }
675
676    #[test]
677    fn avro_3687_basic_enum_with_default_twice() {
678        let non_basic_enum = quote! {
679            enum Basic {
680                #[default]
681                A,
682                B,
683                #[default]
684                C,
685                D
686            }
687        };
688        match syn::parse2::<DeriveInput>(non_basic_enum) {
689            Ok(input) => match derive_avro_schema(input) {
690                Ok(_) => {
691                    panic!("Should not be able to derive schema for enum with multiple defaults")
692                }
693                Err(errors) => {
694                    assert_eq!(errors.len(), 1);
695                    assert_eq!(
696                        errors[0].to_string(),
697                        r#"Multiple defaults defined: ["A", "C"]"#
698                    );
699                }
700            },
701            Err(error) => panic!(
702                "Failed to parse as derive input when it should be able to. Error: {error:?}"
703            ),
704        };
705    }
706
707    #[test]
708    fn test_non_basic_enum() {
709        let non_basic_enum = quote! {
710            enum Basic {
711                A(i32),
712                B,
713                C,
714                D
715            }
716        };
717        match syn::parse2::<DeriveInput>(non_basic_enum) {
718            Ok(input) => {
719                assert!(derive_avro_schema(input).is_err())
720            }
721            Err(error) => panic!(
722                "Failed to parse as derive input when it should be able to. Error: {error:?}"
723            ),
724        };
725    }
726
727    #[test]
728    fn test_namespace() {
729        let test_struct = quote! {
730            #[avro(namespace = "namespace.testing")]
731            struct A {
732                a: i32,
733                b: String
734            }
735        };
736
737        match syn::parse2::<DeriveInput>(test_struct) {
738            Ok(input) => {
739                let schema_token_stream = derive_avro_schema(input);
740                assert!(&schema_token_stream.is_ok());
741                assert!(
742                    schema_token_stream
743                        .unwrap()
744                        .to_string()
745                        .contains("namespace.testing")
746                )
747            }
748            Err(error) => panic!(
749                "Failed to parse as derive input when it should be able to. Error: {error:?}"
750            ),
751        };
752    }
753
754    #[test]
755    fn test_reference() {
756        let test_reference_struct = quote! {
757            struct A<'a> {
758                a: &'a Vec<i32>,
759                b: &'static str
760            }
761        };
762
763        match syn::parse2::<DeriveInput>(test_reference_struct) {
764            Ok(input) => {
765                assert!(derive_avro_schema(input).is_ok())
766            }
767            Err(error) => panic!(
768                "Failed to parse as derive input when it should be able to. Error: {error:?}"
769            ),
770        };
771    }
772
773    #[test]
774    fn test_trait_cast() {
775        assert_eq!(type_to_schema_expr(&syn::parse2::<Type>(quote!{i32}).unwrap()).unwrap().to_string(), quote!{<i32 as apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)}.to_string());
776        assert_eq!(type_to_schema_expr(&syn::parse2::<Type>(quote!{Vec<T>}).unwrap()).unwrap().to_string(), quote!{<Vec<T> as apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)}.to_string());
777        assert_eq!(type_to_schema_expr(&syn::parse2::<Type>(quote!{AnyType}).unwrap()).unwrap().to_string(), quote!{<AnyType as apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)}.to_string());
778    }
779
780    #[test]
781    fn test_avro_3709_record_field_attributes() {
782        let test_struct = quote! {
783            struct A {
784                #[serde(alias = "a1", alias = "a2", rename = "a3")]
785                #[avro(doc = "a doc", default = "123")]
786                a: i32
787            }
788        };
789
790        match syn::parse2::<DeriveInput>(test_struct) {
791            Ok(input) => {
792                let schema_res = derive_avro_schema(input);
793                let expected_token_stream = r#"# [automatically_derived] impl :: apache_avro :: AvroSchemaComponent for A { fn get_schema_in_ctxt (named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro :: schema :: Name > , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: apache_avro :: schema :: Schema { let name = apache_avro :: schema :: Name :: new ("A") . expect (concat ! ("Unable to parse schema name " , "A")) . fully_qualified_name (enclosing_namespace) ; if named_schemas . contains (& name) { apache_avro :: schema :: Schema :: Ref { name } } else { let enclosing_namespace = & name . namespace ; named_schemas . insert (name . clone ()) ; { let mut schema_fields = Vec :: with_capacity (1usize) ; let mut field_position = 0 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . try_into () . expect ("Alias is invalid") , "a2" . try_into () . expect ("Alias is invalid")]) , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; let schema_field_set : :: std :: collections :: HashSet < _ > = schema_fields . iter () . map (| rf | & rf . name) . collect () ; assert_eq ! (schema_fields . len () , schema_field_set . len () , "Duplicate field names found: {schema_fields:?}") ; let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse struct name for schema {}" , "A") [..]) ; let lookup : std :: collections :: BTreeMap < String , usize > = schema_fields . iter () . map (| field | (field . name . to_owned () , field . position)) . 
collect () ; apache_avro :: schema :: Schema :: Record (apache_avro :: schema :: RecordSchema { name , aliases : None , doc : None , fields : schema_fields , lookup , attributes : Default :: default () , }) } } } fn get_record_fields_in_ctxt (mut field_position : usize , named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro :: schema :: Name > , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: std :: option :: Option < :: std :: vec :: Vec < :: apache_avro :: schema :: RecordField >> { let mut schema_fields = Vec :: with_capacity (1usize) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . try_into () . expect ("Alias is invalid") , "a2" . try_into () . expect ("Alias is invalid")]) , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; Some (schema_fields) } }"#;
794                let schema_token_stream = schema_res.unwrap().to_string();
795                assert_eq!(schema_token_stream, expected_token_stream);
796            }
797            Err(error) => panic!(
798                "Failed to parse as derive input when it should be able to. Error: {error:?}"
799            ),
800        };
801
802        let test_enum = quote! {
803            enum A {
804                #[serde(rename = "A3")]
805                Item1,
806            }
807        };
808
809        match syn::parse2::<DeriveInput>(test_enum) {
810            Ok(input) => {
811                let schema_res = derive_avro_schema(input);
812                let expected_token_stream = r#"# [automatically_derived] impl :: apache_avro :: AvroSchemaComponent for A { fn get_schema_in_ctxt (named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro :: schema :: Name > , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: apache_avro :: schema :: Schema { let name = apache_avro :: schema :: Name :: new ("A") . expect (concat ! ("Unable to parse schema name " , "A")) . fully_qualified_name (enclosing_namespace) ; if named_schemas . contains (& name) { apache_avro :: schema :: Schema :: Ref { name } } else { let enclosing_namespace = & name . namespace ; named_schemas . insert (name . clone ()) ; apache_avro :: schema :: Schema :: Enum (apache_avro :: schema :: EnumSchema { name : apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse enum name for schema {}" , "A") [..]) , aliases : None , doc : None , symbols : vec ! ["A3" . to_owned ()] , default : None , attributes : Default :: default () , }) } } fn get_record_fields_in_ctxt (mut field_position : usize , named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro :: schema :: Name > , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: std :: option :: Option < :: std :: vec :: Vec < :: apache_avro :: schema :: RecordField >> { None } }"#;
813                let schema_token_stream = schema_res.unwrap().to_string();
814                assert_eq!(schema_token_stream, expected_token_stream);
815            }
816            Err(error) => panic!(
817                "Failed to parse as derive input when it should be able to. Error: {error:?}"
818            ),
819        };
820    }
821
822    #[test]
823    fn test_avro_rs_207_rename_all_attribute() {
824        let test_struct = quote! {
825            #[serde(rename_all="SCREAMING_SNAKE_CASE")]
826            struct A {
827                item: i32,
828                double_item: i32
829            }
830        };
831
832        match syn::parse2::<DeriveInput>(test_struct) {
833            Ok(input) => {
834                let schema_res = derive_avro_schema(input);
835                let expected_token_stream = r#"# [automatically_derived] impl :: apache_avro :: AvroSchemaComponent for A { fn get_schema_in_ctxt (named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro :: schema :: Name > , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: apache_avro :: schema :: Schema { let name = apache_avro :: schema :: Name :: new ("A") . expect (concat ! ("Unable to parse schema name " , "A")) . fully_qualified_name (enclosing_namespace) ; if named_schemas . contains (& name) { apache_avro :: schema :: Schema :: Ref { name } } else { let enclosing_namespace = & name . namespace ; named_schemas . insert (name . clone ()) ; { let mut schema_fields = Vec :: with_capacity (2usize) ; let mut field_position = 0 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "DOUBLE_ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; let schema_field_set : :: std :: collections :: HashSet < _ > = schema_fields . iter () . map (| rf | & rf . name) . collect () ; assert_eq ! (schema_fields . len () , schema_field_set . len () , "Duplicate field names found: {schema_fields:?}") ; let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! 
("Unable to parse struct name for schema {}" , "A") [..]) ; let lookup : std :: collections :: BTreeMap < String , usize > = schema_fields . iter () . map (| field | (field . name . to_owned () , field . position)) . collect () ; apache_avro :: schema :: Schema :: Record (apache_avro :: schema :: RecordSchema { name , aliases : None , doc : None , fields : schema_fields , lookup , attributes : Default :: default () , }) } } } fn get_record_fields_in_ctxt (mut field_position : usize , named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro :: schema :: Name > , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: std :: option :: Option < :: std :: vec :: Vec < :: apache_avro :: schema :: RecordField >> { let mut schema_fields = Vec :: with_capacity (2usize) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "DOUBLE_ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; Some (schema_fields) } }"#;
836                let schema_token_stream = schema_res.unwrap().to_string();
837                assert_eq!(schema_token_stream, expected_token_stream);
838            }
839            Err(error) => panic!(
840                "Failed to parse as derive input when it should be able to. Error: {error:?}"
841            ),
842        };
843
844        let test_enum = quote! {
845            #[serde(rename_all="SCREAMING_SNAKE_CASE")]
846            enum B {
847                Item,
848                DoubleItem,
849            }
850        };
851
852        match syn::parse2::<DeriveInput>(test_enum) {
853            Ok(input) => {
854                let schema_res = derive_avro_schema(input);
855                let expected_token_stream = r#"# [automatically_derived] impl :: apache_avro :: AvroSchemaComponent for B { fn get_schema_in_ctxt (named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro :: schema :: Name > , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: apache_avro :: schema :: Schema { let name = apache_avro :: schema :: Name :: new ("B") . expect (concat ! ("Unable to parse schema name " , "B")) . fully_qualified_name (enclosing_namespace) ; if named_schemas . contains (& name) { apache_avro :: schema :: Schema :: Ref { name } } else { let enclosing_namespace = & name . namespace ; named_schemas . insert (name . clone ()) ; apache_avro :: schema :: Schema :: Enum (apache_avro :: schema :: EnumSchema { name : apache_avro :: schema :: Name :: new ("B") . expect (& format ! ("Unable to parse enum name for schema {}" , "B") [..]) , aliases : None , doc : None , symbols : vec ! ["ITEM" . to_owned () , "DOUBLE_ITEM" . to_owned ()] , default : None , attributes : Default :: default () , }) } } fn get_record_fields_in_ctxt (mut field_position : usize , named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro :: schema :: Name > , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: std :: option :: Option < :: std :: vec :: Vec < :: apache_avro :: schema :: RecordField >> { None } }"#;
856                let schema_token_stream = schema_res.unwrap().to_string();
857                assert_eq!(schema_token_stream, expected_token_stream);
858            }
859            Err(error) => panic!(
860                "Failed to parse as derive input when it should be able to. Error: {error:?}"
861            ),
862        };
863    }
864
865    #[test]
866    fn test_avro_rs_207_rename_attr_has_priority_over_rename_all_attribute() {
867        let test_struct = quote! {
868            #[serde(rename_all="SCREAMING_SNAKE_CASE")]
869            struct A {
870                item: i32,
871                #[serde(rename="DoubleItem")]
872                double_item: i32
873            }
874        };
875
876        match syn::parse2::<DeriveInput>(test_struct) {
877            Ok(input) => {
878                let schema_res = derive_avro_schema(input);
879                let expected_token_stream = r#"# [automatically_derived] impl :: apache_avro :: AvroSchemaComponent for A { fn get_schema_in_ctxt (named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro :: schema :: Name > , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: apache_avro :: schema :: Schema { let name = apache_avro :: schema :: Name :: new ("A") . expect (concat ! ("Unable to parse schema name " , "A")) . fully_qualified_name (enclosing_namespace) ; if named_schemas . contains (& name) { apache_avro :: schema :: Schema :: Ref { name } } else { let enclosing_namespace = & name . namespace ; named_schemas . insert (name . clone ()) ; { let mut schema_fields = Vec :: with_capacity (2usize) ; let mut field_position = 0 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "DoubleItem" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; let schema_field_set : :: std :: collections :: HashSet < _ > = schema_fields . iter () . map (| rf | & rf . name) . collect () ; assert_eq ! (schema_fields . len () , schema_field_set . len () , "Duplicate field names found: {schema_fields:?}") ; let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! 
("Unable to parse struct name for schema {}" , "A") [..]) ; let lookup : std :: collections :: BTreeMap < String , usize > = schema_fields . iter () . map (| field | (field . name . to_owned () , field . position)) . collect () ; apache_avro :: schema :: Schema :: Record (apache_avro :: schema :: RecordSchema { name , aliases : None , doc : None , fields : schema_fields , lookup , attributes : Default :: default () , }) } } } fn get_record_fields_in_ctxt (mut field_position : usize , named_schemas : & mut :: std :: collections :: HashSet < :: apache_avro :: schema :: Name > , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: std :: option :: Option < :: std :: vec :: Vec < :: apache_avro :: schema :: RecordField >> { let mut schema_fields = Vec :: with_capacity (2usize) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "DoubleItem" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; Some (schema_fields) } }"#;
880                let schema_token_stream = schema_res.unwrap().to_string();
881                assert_eq!(schema_token_stream, expected_token_stream);
882            }
883            Err(error) => panic!(
884                "Failed to parse as derive input when it should be able to. Error: {error:?}"
885            ),
886        };
887    }
888}