apache_avro_derive/
lib.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#![cfg_attr(nightly, feature(proc_macro_diagnostic))]
19
20//! This crate is the implementation of the `AvroSchema` derive macro.
21//! Please use it via the [`apache-avro`](https://crates.io/crates/apache-avro) crate:
22//!
23//! ```no_run
24//! use apache_avro::AvroSchema;
25//!
//! #[derive(AvroSchema)]
//! struct Example {
//!     field: i32,
//! }
27//! ```
28//! Please see the documentation of the [`AvroSchema`] trait for instructions on how to use it.
29//!
30//! [`AvroSchema`]: https://docs.rs/apache-avro/latest/apache_avro/serde/trait.AvroSchema.html
31
32mod attributes;
33mod case;
34
35use proc_macro2::{Span, TokenStream};
36use quote::quote;
37use syn::{
38    Attribute, DataEnum, DataStruct, DeriveInput, Expr, Field, Fields, Generics, Ident, Meta, Type,
39    parse_macro_input, spanned::Spanned,
40};
41
42use crate::{
43    attributes::{FieldOptions, NamedTypeOptions, VariantOptions, With},
44    case::RenameRule,
45};
46
47#[proc_macro_derive(AvroSchema, attributes(avro, serde))]
48// Templated from Serde
49pub fn proc_macro_derive_avro_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
50    let input = parse_macro_input!(input as DeriveInput);
51    derive_avro_schema(input)
52        .unwrap_or_else(to_compile_errors)
53        .into()
54}
55
56fn derive_avro_schema(input: DeriveInput) -> Result<TokenStream, Vec<syn::Error>> {
57    // It would be nice to parse the attributes before the `match`, but we first need to validate that `input` is not a union.
58    // Otherwise a user could get errors related to the attributes and after fixing those get an error because the attributes were on a union.
59    let input_span = input.span();
60    match input.data {
61        syn::Data::Struct(data_struct) => {
62            let named_type_options = NamedTypeOptions::new(&input.ident, &input.attrs, input_span)?;
63            let (get_schema_impl, get_record_fields_impl) = if named_type_options.transparent {
64                get_transparent_struct_schema_def(data_struct.fields, input_span)?
65            } else {
66                let (schema_def, record_fields) =
67                    get_struct_schema_def(&named_type_options, data_struct, input.ident.span())?;
68                (
69                    handle_named_schemas(named_type_options.name, schema_def),
70                    record_fields,
71                )
72            };
73            Ok(create_trait_definition(
74                input.ident,
75                &input.generics,
76                get_schema_impl,
77                get_record_fields_impl,
78            ))
79        }
80        syn::Data::Enum(data_enum) => {
81            let named_type_options = NamedTypeOptions::new(&input.ident, &input.attrs, input_span)?;
82            if named_type_options.transparent {
83                return Err(vec![syn::Error::new(
84                    input_span,
85                    "AvroSchema: `#[serde(transparent)]` is only supported on structs",
86                )]);
87            }
88            let schema_def =
89                get_data_enum_schema_def(&named_type_options, data_enum, input.ident.span())?;
90            let inner = handle_named_schemas(named_type_options.name, schema_def);
91            Ok(create_trait_definition(
92                input.ident,
93                &input.generics,
94                inner,
95                quote! { None },
96            ))
97        }
98        syn::Data::Union(_) => Err(vec![syn::Error::new(
99            input_span,
100            "AvroSchema: derive only works for structs and simple enums",
101        )]),
102    }
103}
104
105/// Generate the trait definition with the correct generics
106fn create_trait_definition(
107    ident: Ident,
108    generics: &Generics,
109    get_schema_impl: TokenStream,
110    get_record_fields_impl: TokenStream,
111) -> TokenStream {
112    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
113    quote! {
114        #[automatically_derived]
115        impl #impl_generics ::apache_avro::AvroSchemaComponent for #ident #ty_generics #where_clause {
116            fn get_schema_in_ctxt(named_schemas: &mut ::apache_avro::schema::Names, enclosing_namespace: &::std::option::Option<::std::string::String>) -> ::apache_avro::schema::Schema {
117                #get_schema_impl
118            }
119
120            fn get_record_fields_in_ctxt(mut field_position: usize, named_schemas: &mut ::apache_avro::schema::Names, enclosing_namespace: &::std::option::Option<::std::string::String>) -> ::std::option::Option<::std::vec::Vec<::apache_avro::schema::RecordField>> {
121                #get_record_fields_impl
122            }
123        }
124    }
125}
126
/// Generate the code that checks `named_schemas` for an already registered schema.
///
/// The emitted code builds the fully qualified name from `full_schema_name` and the
/// `enclosing_namespace` that is in scope at the expansion site. If that name is already a
/// key of `named_schemas`, it evaluates to a `Schema::Ref`. Otherwise it registers a
/// placeholder `Schema::Ref`, evaluates `schema_def` with `enclosing_namespace` shadowed by
/// the namespace of this schema's name, stores the resulting schema under the name and
/// evaluates to it.
///
/// The emitted code panics at runtime if `full_schema_name` cannot be parsed as a schema name.
fn handle_named_schemas(full_schema_name: String, schema_def: TokenStream) -> TokenStream {
    quote! {
        let name = apache_avro::schema::Name::new(#full_schema_name).expect(concat!("Unable to parse schema name ", #full_schema_name)).fully_qualified_name(enclosing_namespace);
        if named_schemas.contains_key(&name) {
            apache_avro::schema::Schema::Ref{name}
        } else {
            // Everything nested inside this schema is resolved relative to its namespace.
            let enclosing_namespace = &name.namespace;
            // This is needed because otherwise recursive types will recurse forever and cause a stack overflow
            // TODO: Breaking change to AvroSchemaComponent, have named_schemas be a set instead
            named_schemas.insert(name.clone(), apache_avro::schema::Schema::Ref{name: name.clone()});
            let schema = #schema_def;
            // Replace the placeholder with the real definition.
            named_schemas.insert(name, schema.clone());
            schema
        }
    }
}
144
145/// Generate a schema definition for a struct.
146fn get_struct_schema_def(
147    container_attrs: &NamedTypeOptions,
148    data_struct: DataStruct,
149    ident_span: Span,
150) -> Result<(TokenStream, TokenStream), Vec<syn::Error>> {
151    let mut record_field_exprs = vec![];
152    match data_struct.fields {
153        Fields::Named(a) => {
154            for field in a.named {
155                let mut name = field
156                    .ident
157                    .as_ref()
158                    .expect("Field must have a name")
159                    .to_string();
160                if let Some(raw_name) = name.strip_prefix("r#") {
161                    name = raw_name.to_string();
162                }
163                let field_attrs = FieldOptions::new(&field.attrs, field.span())?;
164                let doc = preserve_optional(field_attrs.doc);
165                match (field_attrs.rename, container_attrs.rename_all) {
166                    (Some(rename), _) => {
167                        name = rename;
168                    }
169                    (None, rename_all) if rename_all != RenameRule::None => {
170                        name = rename_all.apply_to_field(&name);
171                    }
172                    _ => {}
173                }
174                if field_attrs.skip {
175                    continue;
176                } else if field_attrs.flatten {
177                    // Inline the fields of the child record at runtime, as we don't have access to
178                    // the schema here.
179                    let get_record_fields =
180                        get_field_get_record_fields_expr(&field, field_attrs.with)?;
181                    record_field_exprs.push(quote! {
182                        if let Some(flattened_fields) = #get_record_fields {
183                            field_position += flattened_fields.len();
184                            schema_fields.extend(flattened_fields);
185                        } else {
186                            panic!("{} does not have any fields to flatten to", stringify!(#field));
187                        }
188                    });
189
190                    // Don't add this field as it's been replaced by the child record fields
191                    continue;
192                }
193                let default_value = match field_attrs.default {
194                    Some(default_value) => {
195                        let _: serde_json::Value = serde_json::from_str(&default_value[..])
196                            .map_err(|e| {
197                                vec![syn::Error::new(
198                                    field.ident.span(),
199                                    format!("Invalid avro default json: \n{e}"),
200                                )]
201                            })?;
202                        quote! {
203                            Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str()))
204                        }
205                    }
206                    None => quote! { None },
207                };
208                let aliases = aliases(&field_attrs.alias);
209                let schema_expr = get_field_schema_expr(&field, field_attrs.with)?;
210                record_field_exprs.push(quote! {
211                    schema_fields.push(::apache_avro::schema::RecordField {
212                        name: #name.to_string(),
213                        doc: #doc,
214                        default: #default_value,
215                        aliases: #aliases,
216                        schema: #schema_expr,
217                        order: ::apache_avro::schema::RecordFieldOrder::Ascending,
218                        position: field_position,
219                        custom_attributes: Default::default(),
220                    });
221                    field_position += 1;
222                });
223            }
224        }
225        Fields::Unnamed(_) => {
226            return Err(vec![syn::Error::new(
227                ident_span,
228                "AvroSchema derive does not work for tuple structs",
229            )]);
230        }
231        Fields::Unit => {
232            return Err(vec![syn::Error::new(
233                ident_span,
234                "AvroSchema derive does not work for unit structs",
235            )]);
236        }
237    }
238
239    let record_doc = preserve_optional(container_attrs.doc.as_ref());
240    let record_aliases = aliases(&container_attrs.aliases);
241    let full_schema_name = &container_attrs.name;
242
243    // When flatten is involved, there will be more but we don't know how many. This optimises for
244    // the most common case where there is no flatten.
245    let minimum_fields = record_field_exprs.len();
246
247    let schema_def = quote! {
248        {
249            let mut schema_fields = Vec::with_capacity(#minimum_fields);
250            let mut field_position = 0;
251            #(#record_field_exprs)*
252            let schema_field_set: ::std::collections::HashSet<_> = schema_fields.iter().map(|rf| &rf.name).collect();
253            assert_eq!(schema_fields.len(), schema_field_set.len(), "Duplicate field names found: {schema_fields:?}");
254            let name = apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse struct name for schema {}", #full_schema_name)[..]);
255            let lookup: std::collections::BTreeMap<String, usize> = schema_fields
256                .iter()
257                .map(|field| (field.name.to_owned(), field.position))
258                .collect();
259            apache_avro::schema::Schema::Record(apache_avro::schema::RecordSchema {
260                name,
261                aliases: #record_aliases,
262                doc: #record_doc,
263                fields: schema_fields,
264                lookup,
265                attributes: Default::default(),
266            })
267        }
268    };
269    let record_fields = quote! {
270        let mut schema_fields = Vec::with_capacity(#minimum_fields);
271        #(#record_field_exprs)*
272        Some(schema_fields)
273    };
274
275    Ok((schema_def, record_fields))
276}
277
278/// Use the schema definition of the only field in the struct as the schema
279fn get_transparent_struct_schema_def(
280    fields: Fields,
281    input_span: Span,
282) -> Result<(TokenStream, TokenStream), Vec<syn::Error>> {
283    match fields {
284        Fields::Named(fields_named) => {
285            let mut found = None;
286            for field in fields_named.named {
287                let attrs = FieldOptions::new(&field.attrs, field.span())?;
288                if attrs.skip {
289                    continue;
290                }
291                if found.replace((field, attrs)).is_some() {
292                    return Err(vec![syn::Error::new(
293                        input_span,
294                        "AvroSchema: #[serde(transparent)] is only allowed on structs with one unskipped field",
295                    )]);
296                }
297            }
298
299            if let Some((field, attrs)) = found {
300                Ok((
301                    get_field_schema_expr(&field, attrs.with.clone())?,
302                    get_field_get_record_fields_expr(&field, attrs.with)?,
303                ))
304            } else {
305                Err(vec![syn::Error::new(
306                    input_span,
307                    "AvroSchema: #[serde(transparent)] is only allowed on structs with one unskipped field",
308                )])
309            }
310        }
311        Fields::Unnamed(_) => Err(vec![syn::Error::new(
312            input_span,
313            "AvroSchema: derive does not work for tuple structs",
314        )]),
315        Fields::Unit => Err(vec![syn::Error::new(
316            input_span,
317            "AvroSchema: derive does not work for unit structs",
318        )]),
319    }
320}
321
322fn get_field_schema_expr(field: &Field, with: With) -> Result<TokenStream, Vec<syn::Error>> {
323    match with {
324        With::Trait => Ok(type_to_schema_expr(&field.ty)?),
325        With::Serde(path) => {
326            Ok(quote! { #path::get_schema_in_ctxt(named_schemas, enclosing_namespace) })
327        }
328        With::Expr(Expr::Closure(closure)) => {
329            if closure.inputs.is_empty() {
330                Ok(quote! { (#closure)() })
331            } else {
332                Err(vec![syn::Error::new(
333                    field.span(),
334                    "Expected closure with 0 parameters",
335                )])
336            }
337        }
338        With::Expr(Expr::Path(path)) => Ok(quote! { #path(named_schemas, enclosing_namespace) }),
339        With::Expr(_expr) => Err(vec![syn::Error::new(
340            field.span(),
341            "Invalid expression, expected function or closure",
342        )]),
343    }
344}
345
346fn get_field_get_record_fields_expr(
347    field: &Field,
348    with: With,
349) -> Result<TokenStream, Vec<syn::Error>> {
350    match with {
351        With::Trait => Ok(type_to_get_record_fields_expr(&field.ty)?),
352        With::Serde(path) => Ok(
353            quote! { #path::get_record_fields_in_ctxt(field_position, named_schemas, enclosing_namespace) },
354        ),
355        With::Expr(Expr::Closure(closure)) => {
356            if closure.inputs.is_empty() {
357                Ok(quote! {
358                    ::apache_avro::serde::get_record_fields_in_ctxt(
359                        field_position,
360                        named_schemas,
361                        enclosing_namespace,
362                        |_, _| (#closure)(),
363                    )
364                })
365            } else {
366                Err(vec![syn::Error::new(
367                    field.span(),
368                    "Expected closure with 0 parameters",
369                )])
370            }
371        }
372        With::Expr(Expr::Path(path)) => Ok(quote! {
373            ::apache_avro::serde::get_record_fields_in_ctxt(field_position, named_schemas, enclosing_namespace, #path)
374        }),
375        With::Expr(_expr) => Err(vec![syn::Error::new(
376            field.span(),
377            "Invalid expression, expected function or closure",
378        )]),
379    }
380}
381
382/// Generate a schema definition for a enum.
383fn get_data_enum_schema_def(
384    container_attrs: &NamedTypeOptions,
385    data_enum: DataEnum,
386    ident_span: Span,
387) -> Result<TokenStream, Vec<syn::Error>> {
388    let doc = preserve_optional(container_attrs.doc.as_ref());
389    let enum_aliases = aliases(&container_attrs.aliases);
390    if data_enum.variants.iter().all(|v| Fields::Unit == v.fields) {
391        let default_value = default_enum_variant(&data_enum, ident_span)?;
392        let default = preserve_optional(default_value);
393        let mut symbols = Vec::new();
394        for variant in &data_enum.variants {
395            let field_attrs = VariantOptions::new(&variant.attrs, variant.span())?;
396            let name = match (field_attrs.rename, container_attrs.rename_all) {
397                (Some(rename), _) => rename,
398                (None, rename_all) if !matches!(rename_all, RenameRule::None) => {
399                    rename_all.apply_to_variant(&variant.ident.to_string())
400                }
401                _ => variant.ident.to_string(),
402            };
403            symbols.push(name);
404        }
405        let full_schema_name = &container_attrs.name;
406        Ok(quote! {
407            apache_avro::schema::Schema::Enum(apache_avro::schema::EnumSchema {
408                name: apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse enum name for schema {}", #full_schema_name)[..]),
409                aliases: #enum_aliases,
410                doc: #doc,
411                symbols: vec![#(#symbols.to_owned()),*],
412                default: #default,
413                attributes: Default::default(),
414            })
415        })
416    } else {
417        Err(vec![syn::Error::new(
418            ident_span,
419            "AvroSchema: derive does not work for enums with non unit structs",
420        )])
421    }
422}
423
424/// Takes in the Tokens of a type and returns the tokens of an expression with return type `Schema`
425fn type_to_schema_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
426    match ty {
427        Type::Array(_) | Type::Slice(_) | Type::Path(_) | Type::Reference(_) => Ok(
428            quote! {<#ty as apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
429        ),
430        Type::Ptr(_) => Err(vec![syn::Error::new_spanned(
431            ty,
432            "AvroSchema: derive does not support raw pointers",
433        )]),
434        Type::Tuple(_) => Err(vec![syn::Error::new_spanned(
435            ty,
436            "AvroSchema: derive does not support tuples",
437        )]),
438        _ => Err(vec![syn::Error::new_spanned(
439            ty,
440            format!(
441                "AvroSchema: Unexpected type encountered! Please open an issue if this kind of type should be supported: {ty:?}"
442            ),
443        )]),
444    }
445}
446
447fn type_to_get_record_fields_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
448    match ty {
449        Type::Array(_) | Type::Slice(_) | Type::Path(_) | Type::Reference(_) => Ok(
450            quote! {<#ty as apache_avro::AvroSchemaComponent>::get_record_fields_in_ctxt(field_position, named_schemas, enclosing_namespace)},
451        ),
452        Type::Ptr(_) => Err(vec![syn::Error::new_spanned(
453            ty,
454            "AvroSchema: derive does not support raw pointers",
455        )]),
456        Type::Tuple(_) => Err(vec![syn::Error::new_spanned(
457            ty,
458            "AvroSchema: derive does not support tuples",
459        )]),
460        _ => Err(vec![syn::Error::new_spanned(
461            ty,
462            format!(
463                "AvroSchema: Unexpected type encountered! Please open an issue if this kind of type should be supported: {ty:?}"
464            ),
465        )]),
466    }
467}
468
469fn default_enum_variant(
470    data_enum: &syn::DataEnum,
471    error_span: Span,
472) -> Result<Option<String>, Vec<syn::Error>> {
473    match data_enum
474        .variants
475        .iter()
476        .filter(|v| v.attrs.iter().any(is_default_attr))
477        .collect::<Vec<_>>()
478    {
479        variants if variants.is_empty() => Ok(None),
480        single if single.len() == 1 => Ok(Some(single[0].ident.to_string())),
481        multiple => Err(vec![syn::Error::new(
482            error_span,
483            format!(
484                "Multiple defaults defined: {:?}",
485                multiple
486                    .iter()
487                    .map(|v| v.ident.to_string())
488                    .collect::<Vec<String>>()
489            ),
490        )]),
491    }
492}
493
494fn is_default_attr(attr: &Attribute) -> bool {
495    matches!(attr, Attribute { meta: Meta::Path(path), .. } if path.get_ident().map(Ident::to_string).as_deref() == Some("default"))
496}
497
498/// Stolen from serde
499fn to_compile_errors(errors: Vec<syn::Error>) -> proc_macro2::TokenStream {
500    let compile_errors = errors.iter().map(syn::Error::to_compile_error);
501    quote!(#(#compile_errors)*)
502}
503
504fn preserve_optional(op: Option<impl quote::ToTokens>) -> TokenStream {
505    match op {
506        Some(tt) => quote! {Some(#tt.into())},
507        None => quote! {None},
508    }
509}
510
511fn aliases(op: &[impl quote::ToTokens]) -> TokenStream {
512    let items: Vec<TokenStream> = op
513        .iter()
514        .map(|tt| quote! {#tt.try_into().expect("Alias is invalid")})
515        .collect();
516    if items.is_empty() {
517        quote! {None}
518    } else {
519        quote! {Some(vec![#(#items),*])}
520    }
521}
522
523#[cfg(test)]
524mod tests {
525    use super::*;
526    use pretty_assertions::assert_eq;
527
528    #[test]
529    fn basic_case() {
530        let test_struct = quote! {
531            struct A {
532                a: i32,
533                b: String
534            }
535        };
536
537        match syn::parse2::<DeriveInput>(test_struct) {
538            Ok(input) => {
539                assert!(derive_avro_schema(input).is_ok())
540            }
541            Err(error) => panic!(
542                "Failed to parse as derive input when it should be able to. Error: {error:?}"
543            ),
544        };
545    }
546
547    #[test]
548    fn tuple_struct_unsupported() {
549        let test_tuple_struct = quote! {
550            struct B (i32, String);
551        };
552
553        match syn::parse2::<DeriveInput>(test_tuple_struct) {
554            Ok(input) => {
555                assert!(derive_avro_schema(input).is_err())
556            }
557            Err(error) => panic!(
558                "Failed to parse as derive input when it should be able to. Error: {error:?}"
559            ),
560        };
561    }
562
563    #[test]
564    fn unit_struct_unsupported() {
565        let test_tuple_struct = quote! {
566            struct AbsoluteUnit;
567        };
568
569        match syn::parse2::<DeriveInput>(test_tuple_struct) {
570            Ok(input) => {
571                assert!(derive_avro_schema(input).is_err())
572            }
573            Err(error) => panic!(
574                "Failed to parse as derive input when it should be able to. Error: {error:?}"
575            ),
576        };
577    }
578
579    #[test]
580    fn struct_with_optional() {
581        let struct_with_optional = quote! {
582            struct Test4 {
583                a : Option<i32>
584            }
585        };
586        match syn::parse2::<DeriveInput>(struct_with_optional) {
587            Ok(input) => {
588                assert!(derive_avro_schema(input).is_ok())
589            }
590            Err(error) => panic!(
591                "Failed to parse as derive input when it should be able to. Error: {error:?}"
592            ),
593        };
594    }
595
596    #[test]
597    fn test_basic_enum() {
598        let basic_enum = quote! {
599            enum Basic {
600                A,
601                B,
602                C,
603                D
604            }
605        };
606        match syn::parse2::<DeriveInput>(basic_enum) {
607            Ok(input) => {
608                assert!(derive_avro_schema(input).is_ok())
609            }
610            Err(error) => panic!(
611                "Failed to parse as derive input when it should be able to. Error: {error:?}"
612            ),
613        };
614    }
615
616    #[test]
617    fn avro_3687_basic_enum_with_default() {
618        let basic_enum = quote! {
619            enum Basic {
620                #[default]
621                A,
622                B,
623                C,
624                D
625            }
626        };
627        match syn::parse2::<DeriveInput>(basic_enum) {
628            Ok(input) => {
629                let derived = derive_avro_schema(input);
630                assert!(derived.is_ok());
631                assert_eq!(derived.unwrap().to_string(), quote! {
632                    #[automatically_derived]
633                    impl ::apache_avro::AvroSchemaComponent for Basic {
634                        fn get_schema_in_ctxt(
635                            named_schemas: &mut ::apache_avro::schema::Names,
636                            enclosing_namespace: &::std::option::Option<::std::string::String>
637                        ) -> ::apache_avro::schema::Schema {
638                            let name = apache_avro::schema::Name::new("Basic")
639                                .expect(concat!("Unable to parse schema name ", "Basic"))
640                                .fully_qualified_name(enclosing_namespace);
641                            if named_schemas.contains_key(&name) {
642                                apache_avro::schema::Schema::Ref { name }
643                            } else {
644                                let enclosing_namespace = &name.namespace;
645                                named_schemas.insert(name.clone(), apache_avro::schema::Schema::Ref{name: name.clone()});
646                                let schema =
647                                apache_avro::schema::Schema::Enum(apache_avro::schema::EnumSchema {
648                                    name: apache_avro::schema::Name::new("Basic").expect(
649                                        &format!("Unable to parse enum name for schema {}", "Basic")[..]
650                                    ),
651                                    aliases: None,
652                                    doc: None,
653                                    symbols: vec![
654                                        "A".to_owned(),
655                                        "B".to_owned(),
656                                        "C".to_owned(),
657                                        "D".to_owned()
658                                    ],
659                                    default: Some("A".into()),
660                                    attributes: Default::default(),
661                                });
662                                named_schemas.insert(name, schema.clone());
663                                schema
664                            }
665                        }
666
667                        fn get_record_fields_in_ctxt(
668                            mut field_position: usize,
669                            named_schemas: &mut ::apache_avro::schema::Names,
670                            enclosing_namespace: &::std::option::Option<::std::string::String>
671                        ) -> ::std::option::Option <::std::vec::Vec<::apache_avro::schema::RecordField>> {
672                            None
673                        }
674                    }
675                }.to_string());
676            }
677            Err(error) => panic!(
678                "Failed to parse as derive input when it should be able to. Error: {error:?}"
679            ),
680        };
681    }
682
683    #[test]
684    fn avro_3687_basic_enum_with_default_twice() {
685        let non_basic_enum = quote! {
686            enum Basic {
687                #[default]
688                A,
689                B,
690                #[default]
691                C,
692                D
693            }
694        };
695        match syn::parse2::<DeriveInput>(non_basic_enum) {
696            Ok(input) => match derive_avro_schema(input) {
697                Ok(_) => {
698                    panic!("Should not be able to derive schema for enum with multiple defaults")
699                }
700                Err(errors) => {
701                    assert_eq!(errors.len(), 1);
702                    assert_eq!(
703                        errors[0].to_string(),
704                        r#"Multiple defaults defined: ["A", "C"]"#
705                    );
706                }
707            },
708            Err(error) => panic!(
709                "Failed to parse as derive input when it should be able to. Error: {error:?}"
710            ),
711        };
712    }
713
714    #[test]
715    fn test_non_basic_enum() {
716        let non_basic_enum = quote! {
717            enum Basic {
718                A(i32),
719                B,
720                C,
721                D
722            }
723        };
724        match syn::parse2::<DeriveInput>(non_basic_enum) {
725            Ok(input) => {
726                assert!(derive_avro_schema(input).is_err())
727            }
728            Err(error) => panic!(
729                "Failed to parse as derive input when it should be able to. Error: {error:?}"
730            ),
731        };
732    }
733
734    #[test]
735    fn test_namespace() {
736        let test_struct = quote! {
737            #[avro(namespace = "namespace.testing")]
738            struct A {
739                a: i32,
740                b: String
741            }
742        };
743
744        match syn::parse2::<DeriveInput>(test_struct) {
745            Ok(input) => {
746                let schema_token_stream = derive_avro_schema(input);
747                assert!(&schema_token_stream.is_ok());
748                assert!(
749                    schema_token_stream
750                        .unwrap()
751                        .to_string()
752                        .contains("namespace.testing")
753                )
754            }
755            Err(error) => panic!(
756                "Failed to parse as derive input when it should be able to. Error: {error:?}"
757            ),
758        };
759    }
760
761    #[test]
762    fn test_reference() {
763        let test_reference_struct = quote! {
764            struct A<'a> {
765                a: &'a Vec<i32>,
766                b: &'static str
767            }
768        };
769
770        match syn::parse2::<DeriveInput>(test_reference_struct) {
771            Ok(input) => {
772                assert!(derive_avro_schema(input).is_ok())
773            }
774            Err(error) => panic!(
775                "Failed to parse as derive input when it should be able to. Error: {error:?}"
776            ),
777        };
778    }
779
780    #[test]
781    fn test_trait_cast() {
782        assert_eq!(type_to_schema_expr(&syn::parse2::<Type>(quote!{i32}).unwrap()).unwrap().to_string(), quote!{<i32 as apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)}.to_string());
783        assert_eq!(type_to_schema_expr(&syn::parse2::<Type>(quote!{Vec<T>}).unwrap()).unwrap().to_string(), quote!{<Vec<T> as apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)}.to_string());
784        assert_eq!(type_to_schema_expr(&syn::parse2::<Type>(quote!{AnyType}).unwrap()).unwrap().to_string(), quote!{<AnyType as apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)}.to_string());
785    }
786
787    #[test]
788    fn test_avro_3709_record_field_attributes() {
789        let test_struct = quote! {
790            struct A {
791                #[serde(alias = "a1", alias = "a2", rename = "a3")]
792                #[avro(doc = "a doc", default = "123")]
793                a: i32
794            }
795        };
796
797        match syn::parse2::<DeriveInput>(test_struct) {
798            Ok(input) => {
799                let schema_res = derive_avro_schema(input);
800                let expected_token_stream = r#"# [automatically_derived] impl :: apache_avro :: AvroSchemaComponent for A { fn get_schema_in_ctxt (named_schemas : & mut :: apache_avro :: schema :: Names , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: apache_avro :: schema :: Schema { let name = apache_avro :: schema :: Name :: new ("A") . expect (concat ! ("Unable to parse schema name " , "A")) . fully_qualified_name (enclosing_namespace) ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name } } else { let enclosing_namespace = & name . namespace ; named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let schema = { let mut schema_fields = Vec :: with_capacity (1usize) ; let mut field_position = 0 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . try_into () . expect ("Alias is invalid") , "a2" . try_into () . expect ("Alias is invalid")]) , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; let schema_field_set : :: std :: collections :: HashSet < _ > = schema_fields . iter () . map (| rf | & rf . name) . collect () ; assert_eq ! (schema_fields . len () , schema_field_set . len () , "Duplicate field names found: {schema_fields:?}") ; let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse struct name for schema {}" , "A") [..]) ; let lookup : std :: collections :: BTreeMap < String , usize > = schema_fields . iter () . map (| field | (field . 
name . to_owned () , field . position)) . collect () ; apache_avro :: schema :: Schema :: Record (apache_avro :: schema :: RecordSchema { name , aliases : None , doc : None , fields : schema_fields , lookup , attributes : Default :: default () , }) } ; named_schemas . insert (name , schema . clone ()) ; schema } } fn get_record_fields_in_ctxt (mut field_position : usize , named_schemas : & mut :: apache_avro :: schema :: Names , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: std :: option :: Option < :: std :: vec :: Vec < :: apache_avro :: schema :: RecordField >> { let mut schema_fields = Vec :: with_capacity (1usize) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . try_into () . expect ("Alias is invalid") , "a2" . try_into () . expect ("Alias is invalid")]) , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; Some (schema_fields) } }"#;
801                let schema_token_stream = schema_res.unwrap().to_string();
802                assert_eq!(schema_token_stream, expected_token_stream);
803            }
804            Err(error) => panic!(
805                "Failed to parse as derive input when it should be able to. Error: {error:?}"
806            ),
807        };
808
809        let test_enum = quote! {
810            enum A {
811                #[serde(rename = "A3")]
812                Item1,
813            }
814        };
815
816        match syn::parse2::<DeriveInput>(test_enum) {
817            Ok(input) => {
818                let schema_res = derive_avro_schema(input);
819                let expected_token_stream = r#"# [automatically_derived] impl :: apache_avro :: AvroSchemaComponent for A { fn get_schema_in_ctxt (named_schemas : & mut :: apache_avro :: schema :: Names , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: apache_avro :: schema :: Schema { let name = apache_avro :: schema :: Name :: new ("A") . expect (concat ! ("Unable to parse schema name " , "A")) . fully_qualified_name (enclosing_namespace) ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name } } else { let enclosing_namespace = & name . namespace ; named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let schema = apache_avro :: schema :: Schema :: Enum (apache_avro :: schema :: EnumSchema { name : apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse enum name for schema {}" , "A") [..]) , aliases : None , doc : None , symbols : vec ! ["A3" . to_owned ()] , default : None , attributes : Default :: default () , }) ; named_schemas . insert (name , schema . clone ()) ; schema } } fn get_record_fields_in_ctxt (mut field_position : usize , named_schemas : & mut :: apache_avro :: schema :: Names , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: std :: option :: Option < :: std :: vec :: Vec < :: apache_avro :: schema :: RecordField >> { None } }"#;
820                let schema_token_stream = schema_res.unwrap().to_string();
821                assert_eq!(schema_token_stream, expected_token_stream);
822            }
823            Err(error) => panic!(
824                "Failed to parse as derive input when it should be able to. Error: {error:?}"
825            ),
826        };
827    }
828
829    #[test]
830    fn test_avro_rs_207_rename_all_attribute() {
831        let test_struct = quote! {
832            #[serde(rename_all="SCREAMING_SNAKE_CASE")]
833            struct A {
834                item: i32,
835                double_item: i32
836            }
837        };
838
839        match syn::parse2::<DeriveInput>(test_struct) {
840            Ok(input) => {
841                let schema_res = derive_avro_schema(input);
842                let expected_token_stream = r#"# [automatically_derived] impl :: apache_avro :: AvroSchemaComponent for A { fn get_schema_in_ctxt (named_schemas : & mut :: apache_avro :: schema :: Names , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: apache_avro :: schema :: Schema { let name = apache_avro :: schema :: Name :: new ("A") . expect (concat ! ("Unable to parse schema name " , "A")) . fully_qualified_name (enclosing_namespace) ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name } } else { let enclosing_namespace = & name . namespace ; named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let schema = { let mut schema_fields = Vec :: with_capacity (2usize) ; let mut field_position = 0 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "DOUBLE_ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; let schema_field_set : :: std :: collections :: HashSet < _ > = schema_fields . iter () . map (| rf | & rf . name) . collect () ; assert_eq ! (schema_fields . len () , schema_field_set . 
len () , "Duplicate field names found: {schema_fields:?}") ; let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse struct name for schema {}" , "A") [..]) ; let lookup : std :: collections :: BTreeMap < String , usize > = schema_fields . iter () . map (| field | (field . name . to_owned () , field . position)) . collect () ; apache_avro :: schema :: Schema :: Record (apache_avro :: schema :: RecordSchema { name , aliases : None , doc : None , fields : schema_fields , lookup , attributes : Default :: default () , }) } ; named_schemas . insert (name , schema . clone ()) ; schema } } fn get_record_fields_in_ctxt (mut field_position : usize , named_schemas : & mut :: apache_avro :: schema :: Names , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: std :: option :: Option < :: std :: vec :: Vec < :: apache_avro :: schema :: RecordField >> { let mut schema_fields = Vec :: with_capacity (2usize) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "DOUBLE_ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; Some (schema_fields) } }"#;
843                let schema_token_stream = schema_res.unwrap().to_string();
844                assert_eq!(schema_token_stream, expected_token_stream);
845            }
846            Err(error) => panic!(
847                "Failed to parse as derive input when it should be able to. Error: {error:?}"
848            ),
849        };
850
851        let test_enum = quote! {
852            #[serde(rename_all="SCREAMING_SNAKE_CASE")]
853            enum B {
854                Item,
855                DoubleItem,
856            }
857        };
858
859        match syn::parse2::<DeriveInput>(test_enum) {
860            Ok(input) => {
861                let schema_res = derive_avro_schema(input);
862                let expected_token_stream = r#"# [automatically_derived] impl :: apache_avro :: AvroSchemaComponent for B { fn get_schema_in_ctxt (named_schemas : & mut :: apache_avro :: schema :: Names , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: apache_avro :: schema :: Schema { let name = apache_avro :: schema :: Name :: new ("B") . expect (concat ! ("Unable to parse schema name " , "B")) . fully_qualified_name (enclosing_namespace) ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name } } else { let enclosing_namespace = & name . namespace ; named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let schema = apache_avro :: schema :: Schema :: Enum (apache_avro :: schema :: EnumSchema { name : apache_avro :: schema :: Name :: new ("B") . expect (& format ! ("Unable to parse enum name for schema {}" , "B") [..]) , aliases : None , doc : None , symbols : vec ! ["ITEM" . to_owned () , "DOUBLE_ITEM" . to_owned ()] , default : None , attributes : Default :: default () , }) ; named_schemas . insert (name , schema . clone ()) ; schema } } fn get_record_fields_in_ctxt (mut field_position : usize , named_schemas : & mut :: apache_avro :: schema :: Names , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: std :: option :: Option < :: std :: vec :: Vec < :: apache_avro :: schema :: RecordField >> { None } }"#;
863                let schema_token_stream = schema_res.unwrap().to_string();
864                assert_eq!(schema_token_stream, expected_token_stream);
865            }
866            Err(error) => panic!(
867                "Failed to parse as derive input when it should be able to. Error: {error:?}"
868            ),
869        };
870    }
871
872    #[test]
873    fn test_avro_rs_207_rename_attr_has_priority_over_rename_all_attribute() {
874        let test_struct = quote! {
875            #[serde(rename_all="SCREAMING_SNAKE_CASE")]
876            struct A {
877                item: i32,
878                #[serde(rename="DoubleItem")]
879                double_item: i32
880            }
881        };
882
883        match syn::parse2::<DeriveInput>(test_struct) {
884            Ok(input) => {
885                let schema_res = derive_avro_schema(input);
886                let expected_token_stream = r#"# [automatically_derived] impl :: apache_avro :: AvroSchemaComponent for A { fn get_schema_in_ctxt (named_schemas : & mut :: apache_avro :: schema :: Names , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: apache_avro :: schema :: Schema { let name = apache_avro :: schema :: Name :: new ("A") . expect (concat ! ("Unable to parse schema name " , "A")) . fully_qualified_name (enclosing_namespace) ; if named_schemas . contains_key (& name) { apache_avro :: schema :: Schema :: Ref { name } } else { let enclosing_namespace = & name . namespace ; named_schemas . insert (name . clone () , apache_avro :: schema :: Schema :: Ref { name : name . clone () }) ; let schema = { let mut schema_fields = Vec :: with_capacity (2usize) ; let mut field_position = 0 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "DoubleItem" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; let schema_field_set : :: std :: collections :: HashSet < _ > = schema_fields . iter () . map (| rf | & rf . name) . collect () ; assert_eq ! (schema_fields . len () , schema_field_set . 
len () , "Duplicate field names found: {schema_fields:?}") ; let name = apache_avro :: schema :: Name :: new ("A") . expect (& format ! ("Unable to parse struct name for schema {}" , "A") [..]) ; let lookup : std :: collections :: BTreeMap < String , usize > = schema_fields . iter () . map (| field | (field . name . to_owned () , field . position)) . collect () ; apache_avro :: schema :: Schema :: Record (apache_avro :: schema :: RecordSchema { name , aliases : None , doc : None , fields : schema_fields , lookup , attributes : Default :: default () , }) } ; named_schemas . insert (name , schema . clone ()) ; schema } } fn get_record_fields_in_ctxt (mut field_position : usize , named_schemas : & mut :: apache_avro :: schema :: Names , enclosing_namespace : & :: std :: option :: Option < :: std :: string :: String >) -> :: std :: option :: Option < :: std :: vec :: Vec < :: apache_avro :: schema :: RecordField >> { let mut schema_fields = Vec :: with_capacity (2usize) ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "ITEM" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; schema_fields . push (:: apache_avro :: schema :: RecordField { name : "DoubleItem" . to_string () , doc : None , default : None , aliases : None , schema : < i32 as apache_avro :: AvroSchemaComponent > :: get_schema_in_ctxt (named_schemas , enclosing_namespace) , order : :: apache_avro :: schema :: RecordFieldOrder :: Ascending , position : field_position , custom_attributes : Default :: default () , }) ; field_position += 1 ; Some (schema_fields) } }"#;
887                let schema_token_stream = schema_res.unwrap().to_string();
888                assert_eq!(schema_token_stream, expected_token_stream);
889            }
890            Err(error) => panic!(
891                "Failed to parse as derive input when it should be able to. Error: {error:?}"
892            ),
893        };
894    }
895}