apache_avro_derive/
lib.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use darling::FromAttributes;
19use proc_macro2::{Span, TokenStream};
20use quote::quote;
21
22use syn::{
23    parse_macro_input, spanned::Spanned, AttrStyle, Attribute, DeriveInput, Ident, Meta, Type,
24    TypePath,
25};
26
27#[derive(darling::FromAttributes)]
28#[darling(attributes(avro))]
29struct FieldOptions {
30    #[darling(default)]
31    doc: Option<String>,
32    #[darling(default)]
33    default: Option<String>,
34    #[darling(multiple)]
35    alias: Vec<String>,
36    #[darling(default)]
37    rename: Option<String>,
38    #[darling(default)]
39    skip: Option<bool>,
40}
41
42#[derive(darling::FromAttributes)]
43#[darling(attributes(avro))]
44struct NamedTypeOptions {
45    #[darling(default)]
46    namespace: Option<String>,
47    #[darling(default)]
48    doc: Option<String>,
49    #[darling(multiple)]
50    alias: Vec<String>,
51}
52
53#[proc_macro_derive(AvroSchema, attributes(avro))]
54// Templated from Serde
55pub fn proc_macro_derive_avro_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
56    let mut input = parse_macro_input!(input as DeriveInput);
57    derive_avro_schema(&mut input)
58        .unwrap_or_else(to_compile_errors)
59        .into()
60}
61
62fn derive_avro_schema(input: &mut DeriveInput) -> Result<TokenStream, Vec<syn::Error>> {
63    let named_type_options =
64        NamedTypeOptions::from_attributes(&input.attrs[..]).map_err(darling_to_syn)?;
65    let full_schema_name = vec![named_type_options.namespace, Some(input.ident.to_string())]
66        .into_iter()
67        .flatten()
68        .collect::<Vec<String>>()
69        .join(".");
70    let schema_def = match &input.data {
71        syn::Data::Struct(s) => get_data_struct_schema_def(
72            &full_schema_name,
73            named_type_options
74                .doc
75                .or_else(|| extract_outer_doc(&input.attrs)),
76            named_type_options.alias,
77            s,
78            input.ident.span(),
79        )?,
80        syn::Data::Enum(e) => get_data_enum_schema_def(
81            &full_schema_name,
82            named_type_options
83                .doc
84                .or_else(|| extract_outer_doc(&input.attrs)),
85            named_type_options.alias,
86            e,
87            input.ident.span(),
88        )?,
89        _ => {
90            return Err(vec![syn::Error::new(
91                input.ident.span(),
92                "AvroSchema derive only works for structs and simple enums ",
93            )])
94        }
95    };
96    let ident = &input.ident;
97    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
98    Ok(quote! {
99        impl #impl_generics apache_avro::schema::derive::AvroSchemaComponent for #ident #ty_generics #where_clause {
100            fn get_schema_in_ctxt(named_schemas: &mut std::collections::HashMap<apache_avro::schema::Name, apache_avro::schema::Schema>, enclosing_namespace: &Option<String>) -> apache_avro::schema::Schema {
101                let name =  apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse schema name {}", #full_schema_name)[..]).fully_qualified_name(enclosing_namespace);
102                let enclosing_namespace = &name.namespace;
103                if named_schemas.contains_key(&name) {
104                    apache_avro::schema::Schema::Ref{name: name.clone()}
105                } else {
106                    named_schemas.insert(name.clone(), apache_avro::schema::Schema::Ref{name: name.clone()});
107                    #schema_def
108                }
109            }
110        }
111    })
112}
113
114fn get_data_struct_schema_def(
115    full_schema_name: &str,
116    record_doc: Option<String>,
117    aliases: Vec<String>,
118    s: &syn::DataStruct,
119    error_span: Span,
120) -> Result<TokenStream, Vec<syn::Error>> {
121    let mut record_field_exprs = vec![];
122    match s.fields {
123        syn::Fields::Named(ref a) => {
124            let mut index: usize = 0;
125            for field in a.named.iter() {
126                let mut name = field.ident.as_ref().unwrap().to_string(); // we know everything has a name
127                if let Some(raw_name) = name.strip_prefix("r#") {
128                    name = raw_name.to_string();
129                }
130                let field_attrs =
131                    FieldOptions::from_attributes(&field.attrs[..]).map_err(darling_to_syn)?;
132                let doc =
133                    preserve_optional(field_attrs.doc.or_else(|| extract_outer_doc(&field.attrs)));
134                if let Some(rename) = field_attrs.rename {
135                    name = rename
136                }
137                if let Some(true) = field_attrs.skip {
138                    continue;
139                }
140                let default_value = match field_attrs.default {
141                    Some(default_value) => {
142                        let _: serde_json::Value = serde_json::from_str(&default_value[..])
143                            .map_err(|e| {
144                                vec![syn::Error::new(
145                                    field.ident.span(),
146                                    format!("Invalid avro default json: \n{e}"),
147                                )]
148                            })?;
149                        quote! {
150                            Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str()))
151                        }
152                    }
153                    None => quote! { None },
154                };
155                let aliases = preserve_vec(field_attrs.alias);
156                let schema_expr = type_to_schema_expr(&field.ty)?;
157                let position = index;
158                record_field_exprs.push(quote! {
159                    apache_avro::schema::RecordField {
160                            name: #name.to_string(),
161                            doc: #doc,
162                            default: #default_value,
163                            aliases: #aliases,
164                            schema: #schema_expr,
165                            order: apache_avro::schema::RecordFieldOrder::Ascending,
166                            position: #position,
167                            custom_attributes: Default::default(),
168                        }
169                });
170                index += 1;
171            }
172        }
173        syn::Fields::Unnamed(_) => {
174            return Err(vec![syn::Error::new(
175                error_span,
176                "AvroSchema derive does not work for tuple structs",
177            )])
178        }
179        syn::Fields::Unit => {
180            return Err(vec![syn::Error::new(
181                error_span,
182                "AvroSchema derive does not work for unit structs",
183            )])
184        }
185    }
186    let record_doc = preserve_optional(record_doc);
187    let record_aliases = preserve_vec(aliases);
188    Ok(quote! {
189        let schema_fields = vec![#(#record_field_exprs),*];
190        let name = apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse struct name for schema {}", #full_schema_name)[..]);
191        let lookup: std::collections::BTreeMap<String, usize> = schema_fields
192            .iter()
193            .map(|field| (field.name.to_owned(), field.position))
194            .collect();
195        apache_avro::schema::Schema::Record(apache_avro::schema::RecordSchema {
196            name,
197            aliases: #record_aliases,
198            doc: #record_doc,
199            fields: schema_fields,
200            lookup,
201            attributes: Default::default(),
202        })
203    })
204}
205
206fn get_data_enum_schema_def(
207    full_schema_name: &str,
208    doc: Option<String>,
209    aliases: Vec<String>,
210    e: &syn::DataEnum,
211    error_span: Span,
212) -> Result<TokenStream, Vec<syn::Error>> {
213    let doc = preserve_optional(doc);
214    let enum_aliases = preserve_vec(aliases);
215    if e.variants.iter().all(|v| syn::Fields::Unit == v.fields) {
216        let default_value = default_enum_variant(e, error_span)?;
217        let default = preserve_optional(default_value);
218        let symbols: Vec<String> = e
219            .variants
220            .iter()
221            .map(|variant| variant.ident.to_string())
222            .collect();
223        Ok(quote! {
224            apache_avro::schema::Schema::Enum(apache_avro::schema::EnumSchema {
225                name: apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse enum name for schema {}", #full_schema_name)[..]),
226                aliases: #enum_aliases,
227                doc: #doc,
228                symbols: vec![#(#symbols.to_owned()),*],
229                default: #default,
230                attributes: Default::default(),
231            })
232        })
233    } else {
234        Err(vec![syn::Error::new(
235            error_span,
236            "AvroSchema derive does not work for enums with non unit structs",
237        )])
238    }
239}
240
241/// Takes in the Tokens of a type and returns the tokens of an expression with return type `Schema`
242fn type_to_schema_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
243    if let Type::Path(p) = ty {
244        let type_string = p.path.segments.last().unwrap().ident.to_string();
245
246        let schema = match &type_string[..] {
247            "bool" => quote! {apache_avro::schema::Schema::Boolean},
248            "i8" | "i16" | "i32" | "u8" | "u16" => quote! {apache_avro::schema::Schema::Int},
249            "u32" | "i64" => quote! {apache_avro::schema::Schema::Long},
250            "f32" => quote! {apache_avro::schema::Schema::Float},
251            "f64" => quote! {apache_avro::schema::Schema::Double},
252            "String" | "str" => quote! {apache_avro::schema::Schema::String},
253            "char" => {
254                return Err(vec![syn::Error::new_spanned(
255                    ty,
256                    "AvroSchema: Cannot guarantee successful deserialization of this type",
257                )])
258            }
259            "u64" => {
260                return Err(vec![syn::Error::new_spanned(
261                ty,
262                "Cannot guarantee successful serialization of this type due to overflow concerns",
263            )])
264            } // Can't guarantee serialization type
265            _ => {
266                // Fails when the type does not implement AvroSchemaComponent directly
267                // TODO check and error report with something like https://docs.rs/quote/1.0.15/quote/macro.quote_spanned.html#example
268                type_path_schema_expr(p)
269            }
270        };
271        Ok(schema)
272    } else if let Type::Array(ta) = ty {
273        let inner_schema_expr = type_to_schema_expr(&ta.elem)?;
274        Ok(quote! {apache_avro::schema::Schema::array(#inner_schema_expr)})
275    } else if let Type::Reference(tr) = ty {
276        type_to_schema_expr(&tr.elem)
277    } else {
278        Err(vec![syn::Error::new_spanned(
279            ty,
280            format!("Unable to generate schema for type: {ty:?}"),
281        )])
282    }
283}
284
285fn default_enum_variant(
286    data_enum: &syn::DataEnum,
287    error_span: Span,
288) -> Result<Option<String>, Vec<syn::Error>> {
289    match data_enum
290        .variants
291        .iter()
292        .filter(|v| v.attrs.iter().any(is_default_attr))
293        .collect::<Vec<_>>()
294    {
295        variants if variants.is_empty() => Ok(None),
296        single if single.len() == 1 => Ok(Some(single[0].ident.to_string())),
297        multiple => Err(vec![syn::Error::new(
298            error_span,
299            format!(
300                "Multiple defaults defined: {:?}",
301                multiple
302                    .iter()
303                    .map(|v| v.ident.to_string())
304                    .collect::<Vec<String>>()
305            ),
306        )]),
307    }
308}
309
310fn is_default_attr(attr: &Attribute) -> bool {
311    matches!(attr, Attribute { meta: Meta::Path(path), .. } if path.get_ident().map(Ident::to_string).as_deref() == Some("default"))
312}
313
314/// Generates the schema def expression for fully qualified type paths using the associated function
315/// - `A -> <A as apache_avro::schema::derive::AvroSchemaComponent>::get_schema_in_ctxt()`
316/// - `A<T> -> <A<T> as apache_avro::schema::derive::AvroSchemaComponent>::get_schema_in_ctxt()`
317fn type_path_schema_expr(p: &TypePath) -> TokenStream {
318    quote! {<#p as apache_avro::schema::derive::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)}
319}
320
321/// Stolen from serde
322fn to_compile_errors(errors: Vec<syn::Error>) -> proc_macro2::TokenStream {
323    let compile_errors = errors.iter().map(syn::Error::to_compile_error);
324    quote!(#(#compile_errors)*)
325}
326
327fn extract_outer_doc(attributes: &[Attribute]) -> Option<String> {
328    let doc = attributes
329        .iter()
330        .filter(|attr| attr.style == AttrStyle::Outer && attr.path().is_ident("doc"))
331        .filter_map(|attr| {
332            let name_value = attr.meta.require_name_value();
333            match name_value {
334                Ok(name_value) => match &name_value.value {
335                    syn::Expr::Lit(expr_lit) => match expr_lit.lit {
336                        syn::Lit::Str(ref lit_str) => Some(lit_str.value().trim().to_string()),
337                        _ => None,
338                    },
339                    _ => None,
340                },
341                Err(_) => None,
342            }
343        })
344        .collect::<Vec<String>>()
345        .join("\n");
346    if doc.is_empty() {
347        None
348    } else {
349        Some(doc)
350    }
351}
352
353fn preserve_optional(op: Option<impl quote::ToTokens>) -> TokenStream {
354    match op {
355        Some(tt) => quote! {Some(#tt.into())},
356        None => quote! {None},
357    }
358}
359
360fn preserve_vec(op: Vec<impl quote::ToTokens>) -> TokenStream {
361    let items: Vec<TokenStream> = op.iter().map(|tt| quote! {#tt.into()}).collect();
362    if items.is_empty() {
363        quote! {None}
364    } else {
365        quote! {Some(vec![#(#items),*])}
366    }
367}
368
369fn darling_to_syn(e: darling::Error) -> Vec<syn::Error> {
370    let msg = format!("{e}");
371    let token_errors = e.write_errors();
372    vec![syn::Error::new(token_errors.span(), msg)]
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a fixture as `DeriveInput`. A parse failure means the fixture itself is malformed,
    /// so the test panics immediately.
    fn parse_fixture(tokens: TokenStream) -> DeriveInput {
        match syn::parse2::<DeriveInput>(tokens) {
            Ok(input) => input,
            Err(error) => panic!(
                "Failed to parse as derive input when it should be able to. Error: {error:?}"
            ),
        }
    }

    #[test]
    fn basic_case() {
        let mut input = parse_fixture(quote! {
            struct A {
                a: i32,
                b: String
            }
        });
        assert!(derive_avro_schema(&mut input).is_ok());
    }

    #[test]
    fn tuple_struct_unsupported() {
        let mut input = parse_fixture(quote! {
            struct B (i32, String);
        });
        assert!(derive_avro_schema(&mut input).is_err());
    }

    #[test]
    fn unit_struct_unsupported() {
        let mut input = parse_fixture(quote! {
            struct AbsoluteUnit;
        });
        assert!(derive_avro_schema(&mut input).is_err());
    }

    #[test]
    fn struct_with_optional() {
        let mut input = parse_fixture(quote! {
            struct Test4 {
                a : Option<i32>
            }
        });
        assert!(derive_avro_schema(&mut input).is_ok());
    }

    #[test]
    fn test_basic_enum() {
        let mut input = parse_fixture(quote! {
            enum Basic {
                A,
                B,
                C,
                D
            }
        });
        assert!(derive_avro_schema(&mut input).is_ok());
    }

    #[test]
    fn avro_3687_basic_enum_with_default() {
        let mut input = parse_fixture(quote! {
            enum Basic {
                #[default]
                A,
                B,
                C,
                D
            }
        });
        let derived = derive_avro_schema(&mut input);
        assert!(derived.is_ok());

        let expected = quote! {
            impl apache_avro::schema::derive::AvroSchemaComponent for Basic {
                fn get_schema_in_ctxt(
                    named_schemas: &mut std::collections::HashMap<
                        apache_avro::schema::Name,
                        apache_avro::schema::Schema
                    >,
                    enclosing_namespace: &Option<String>
                ) -> apache_avro::schema::Schema {
                    let name = apache_avro::schema::Name::new("Basic")
                        .expect(&format!("Unable to parse schema name {}", "Basic")[..])
                        .fully_qualified_name(enclosing_namespace);
                    let enclosing_namespace = &name.namespace;
                    if named_schemas.contains_key(&name) {
                        apache_avro::schema::Schema::Ref { name: name.clone() }
                    } else {
                        named_schemas.insert(
                            name.clone(),
                            apache_avro::schema::Schema::Ref { name: name.clone() }
                        );
                        apache_avro::schema::Schema::Enum(apache_avro::schema::EnumSchema {
                            name: apache_avro::schema::Name::new("Basic").expect(
                                &format!("Unable to parse enum name for schema {}", "Basic")[..]
                            ),
                            aliases: None,
                            doc: None,
                            symbols: vec![
                                "A".to_owned(),
                                "B".to_owned(),
                                "C".to_owned(),
                                "D".to_owned()
                            ],
                            default: Some("A".into()),
                            attributes: Default::default(),
                        })
                    }
                }
            }
        };
        assert_eq!(derived.unwrap().to_string(), expected.to_string());
    }

    #[test]
    fn avro_3687_basic_enum_with_default_twice() {
        let mut input = parse_fixture(quote! {
            enum Basic {
                #[default]
                A,
                B,
                #[default]
                C,
                D
            }
        });
        let errors = match derive_avro_schema(&mut input) {
            Ok(_) => {
                panic!("Should not be able to derive schema for enum with multiple defaults")
            }
            Err(errors) => errors,
        };
        assert_eq!(errors.len(), 1);
        assert_eq!(
            errors[0].to_string(),
            r#"Multiple defaults defined: ["A", "C"]"#
        );
    }

    #[test]
    fn test_non_basic_enum() {
        let mut input = parse_fixture(quote! {
            enum Basic {
                A(i32),
                B,
                C,
                D
            }
        });
        assert!(derive_avro_schema(&mut input).is_err());
    }

    #[test]
    fn test_namespace() {
        let mut input = parse_fixture(quote! {
            #[avro(namespace = "namespace.testing")]
            struct A {
                a: i32,
                b: String
            }
        });
        let schema_token_stream = derive_avro_schema(&mut input);
        assert!(schema_token_stream.is_ok());
        assert!(schema_token_stream
            .unwrap()
            .to_string()
            .contains("namespace.testing"));
    }

    #[test]
    fn test_reference() {
        let mut input = parse_fixture(quote! {
            struct A<'a> {
                a: &'a Vec<i32>,
                b: &'static str
            }
        });
        assert!(derive_avro_schema(&mut input).is_ok());
    }

    #[test]
    fn test_trait_cast() {
        let assert_cast = |ty: TokenStream, expected: TokenStream| {
            let path = syn::parse2::<TypePath>(ty).unwrap();
            assert_eq!(
                type_path_schema_expr(&path).to_string(),
                expected.to_string()
            );
        };
        assert_cast(
            quote! {i32},
            quote! {<i32 as apache_avro::schema::derive::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
        );
        assert_cast(
            quote! {Vec<T>},
            quote! {<Vec<T> as apache_avro::schema::derive::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
        );
        assert_cast(
            quote! {AnyType},
            quote! {<AnyType as apache_avro::schema::derive::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
        );
    }

    #[test]
    fn test_avro_3709_record_field_attributes() {
        let mut input = parse_fixture(quote! {
            struct A {
                #[avro(alias = "a1", alias = "a2", doc = "a doc", default = "123", rename = "a3")]
                a: i32
            }
        });
        let expected_token_stream = r#"let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_attributes : Default :: default () , }] ;"#;
        let schema_token_stream = derive_avro_schema(&mut input).unwrap().to_string();
        assert!(schema_token_stream.contains(expected_token_stream));
    }
}