apache_avro_derive/
lib.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18#![cfg_attr(nightly, feature(proc_macro_diagnostic))]
19
20//! This crate is the implementation of the `AvroSchema` derive macro.
21//! Please use it via the [`apache-avro`](https://crates.io/crates/apache-avro) crate:
22//!
23//! ```no_run
24//! use apache_avro::AvroSchema;
25//!
26//! #[derive(AvroSchema)]
27//! ```
28//! Please see the documentation of the [`AvroSchema`] trait for instructions on how to use it.
29//!
30//! [`AvroSchema`]: https://docs.rs/apache-avro/latest/apache_avro/serde/trait.AvroSchema.html
31
32mod attributes;
33mod case;
34mod enums;
35
36use proc_macro2::{Span, TokenStream};
37use quote::quote;
38use syn::{
39    DataStruct, DeriveInput, Expr, Field, Fields, Generics, Ident, Type, parse_macro_input,
40    spanned::Spanned,
41};
42
43use crate::enums::get_data_enum_schema_def;
44use crate::{
45    attributes::{FieldDefault, FieldOptions, NamedTypeOptions, With},
46    case::RenameRule,
47};
48
49#[proc_macro_derive(AvroSchema, attributes(avro, serde))]
50// Templated from Serde
51pub fn proc_macro_derive_avro_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
52    let input = parse_macro_input!(input as DeriveInput);
53    derive_avro_schema(input)
54        .unwrap_or_else(to_compile_errors)
55        .into()
56}
57
58fn derive_avro_schema(input: DeriveInput) -> Result<TokenStream, Vec<syn::Error>> {
59    // It would be nice to parse the attributes before the `match`, but we first need to validate that `input` is not a union.
60    // Otherwise a user could get errors related to the attributes and after fixing those get an error because the attributes were on a union.
61    let input_span = input.span();
62    match input.data {
63        syn::Data::Struct(data_struct) => {
64            let named_type_options = NamedTypeOptions::new(&input.ident, &input.attrs, input_span)?;
65            let (get_schema_impl, get_record_fields_impl) = if named_type_options.transparent {
66                get_transparent_struct_schema_def(data_struct.fields, input_span)?
67            } else {
68                let (schema_def, record_fields) =
69                    get_struct_schema_def(&named_type_options, data_struct, input.ident.span())?;
70                (
71                    handle_named_schemas(named_type_options.name, schema_def),
72                    record_fields,
73                )
74            };
75            Ok(create_trait_definition(
76                input.ident,
77                &input.generics,
78                get_schema_impl,
79                get_record_fields_impl,
80                named_type_options.default,
81            ))
82        }
83        syn::Data::Enum(data_enum) => {
84            let named_type_options = NamedTypeOptions::new(&input.ident, &input.attrs, input_span)?;
85            if named_type_options.transparent {
86                return Err(vec![syn::Error::new(
87                    input_span,
88                    "AvroSchema: `#[serde(transparent)]` is only supported on structs",
89                )]);
90            }
91            let schema_def =
92                get_data_enum_schema_def(&named_type_options, data_enum, input.ident.span())?;
93            let inner = handle_named_schemas(named_type_options.name, schema_def);
94            Ok(create_trait_definition(
95                input.ident,
96                &input.generics,
97                inner,
98                quote! { ::std::option::Option::None },
99                named_type_options.default,
100            ))
101        }
102        syn::Data::Union(_) => Err(vec![syn::Error::new(
103            input_span,
104            "AvroSchema: derive only works for structs and simple enums",
105        )]),
106    }
107}
108
109/// Generate the trait definition with the correct generics
110fn create_trait_definition(
111    ident: Ident,
112    generics: &Generics,
113    get_schema_impl: TokenStream,
114    get_record_fields_impl: TokenStream,
115    field_default_impl: TokenStream,
116) -> TokenStream {
117    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
118    quote! {
119        #[automatically_derived]
120        impl #impl_generics ::apache_avro::AvroSchemaComponent for #ident #ty_generics #where_clause {
121            fn get_schema_in_ctxt(named_schemas: &mut ::std::collections::HashSet<::apache_avro::schema::Name>, enclosing_namespace: ::apache_avro::schema::NamespaceRef) -> ::apache_avro::schema::Schema {
122                #get_schema_impl
123            }
124
125            fn get_record_fields_in_ctxt(named_schemas: &mut ::std::collections::HashSet<::apache_avro::schema::Name>, enclosing_namespace: ::apache_avro::schema::NamespaceRef) -> ::std::option::Option<::std::vec::Vec<::apache_avro::schema::RecordField>> {
126                #get_record_fields_impl
127            }
128
129            fn field_default() -> ::std::option::Option<::serde_json::Value> {
130                ::std::option::Option::#field_default_impl
131            }
132        }
133    }
134}
135
136/// Generate the code to check `named_schemas` if this schema already exist
137fn handle_named_schemas(full_schema_name: String, schema_def: TokenStream) -> TokenStream {
138    quote! {
139        let name = ::apache_avro::schema::Name::new_with_enclosing_namespace(#full_schema_name, enclosing_namespace).expect(concat!("Unable to parse schema name ", #full_schema_name));
140        if named_schemas.contains(&name) {
141            ::apache_avro::schema::Schema::Ref{name}
142        } else {
143            let enclosing_namespace = name.namespace();
144            named_schemas.insert(name.clone());
145            #schema_def
146        }
147    }
148}
149
150/// Generate a schema definition for a struct.
151fn get_struct_schema_def(
152    container_attrs: &NamedTypeOptions,
153    data_struct: DataStruct,
154    ident_span: Span,
155) -> Result<(TokenStream, TokenStream), Vec<syn::Error>> {
156    let mut record_field_exprs = vec![];
157    match data_struct.fields {
158        Fields::Named(a) => {
159            for field in a.named {
160                let mut name = field
161                    .ident
162                    .as_ref()
163                    .expect("Field must have a name")
164                    .to_string();
165                if let Some(raw_name) = name.strip_prefix("r#") {
166                    name = raw_name.to_string();
167                }
168                let field_attrs = FieldOptions::new(&field.attrs, field.span())?;
169                let doc = preserve_optional(field_attrs.doc);
170                match (field_attrs.rename, container_attrs.rename_all) {
171                    (Some(rename), _) => {
172                        name = rename;
173                    }
174                    (None, rename_all) if rename_all != RenameRule::None => {
175                        name = rename_all.apply_to_field(&name);
176                    }
177                    _ => {}
178                }
179                if field_attrs.skip {
180                    continue;
181                } else if field_attrs.flatten {
182                    // Inline the fields of the child record at runtime, as we don't have access to
183                    // the schema here.
184                    let get_record_fields =
185                        get_field_get_record_fields_expr(&field, field_attrs.with)?;
186                    record_field_exprs.push(quote! {
187                        if let Some(flattened_fields) = #get_record_fields {
188                            schema_fields.extend(flattened_fields);
189                        } else {
190                            panic!("{} does not have any fields to flatten to", stringify!(#field));
191                        }
192                    });
193
194                    // Don't add this field as it's been replaced by the child record fields
195                    continue;
196                }
197                let default_value = match field_attrs.default {
198                    FieldDefault::Disabled => quote! { ::std::option::Option::None },
199                    FieldDefault::Trait => type_to_field_default_expr(&field.ty)?,
200                    FieldDefault::Value(default_value) => {
201                        let _: serde_json::Value = serde_json::from_str(&default_value[..])
202                            .map_err(|e| {
203                                vec![syn::Error::new(
204                                    field.ident.span(),
205                                    format!("Invalid avro default json: \n{e}"),
206                                )]
207                            })?;
208                        quote! {
209                            ::std::option::Option::Some(::serde_json::from_str(#default_value).expect("Unreachable! This parsed at compile time!"))
210                        }
211                    }
212                };
213                let aliases = field_aliases(&field_attrs.alias);
214                let schema_expr = get_field_schema_expr(&field, field_attrs.with)?;
215                record_field_exprs.push(quote! {
216                    schema_fields.push(::apache_avro::schema::RecordField {
217                        name: #name.to_string(),
218                        doc: #doc,
219                        default: #default_value,
220                        aliases: #aliases,
221                        schema: #schema_expr,
222                        custom_attributes: ::std::collections::BTreeMap::new(),
223                    });
224                });
225            }
226        }
227        Fields::Unnamed(_) => {
228            return Err(vec![syn::Error::new(
229                ident_span,
230                "AvroSchema derive does not work for tuple structs",
231            )]);
232        }
233        Fields::Unit => {
234            return Err(vec![syn::Error::new(
235                ident_span,
236                "AvroSchema derive does not work for unit structs",
237            )]);
238        }
239    }
240
241    let record_doc = preserve_optional(container_attrs.doc.as_ref());
242    let record_aliases = aliases(&container_attrs.aliases);
243    let full_schema_name = &container_attrs.name;
244
245    // When flatten is involved, there will be more but we don't know how many. This optimises for
246    // the most common case where there is no flatten.
247    let minimum_fields = record_field_exprs.len();
248
249    let schema_def = quote! {
250        {
251            let mut schema_fields = ::std::vec::Vec::with_capacity(#minimum_fields);
252            #(#record_field_exprs)*
253            let schema_field_set: ::std::collections::HashSet<_> = schema_fields.iter().map(|rf| &rf.name).collect();
254            assert_eq!(schema_fields.len(), schema_field_set.len(), "Duplicate field names found: {schema_fields:?}");
255            let name = ::apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse struct name for schema {}", #full_schema_name)[..]);
256            let lookup: ::std::collections::BTreeMap<String, usize> = schema_fields
257                .iter()
258                .enumerate()
259                .map(|(position, field)| (field.name.to_owned(), position))
260                .collect();
261            ::apache_avro::schema::Schema::Record(::apache_avro::schema::RecordSchema {
262                name,
263                aliases: #record_aliases,
264                doc: #record_doc,
265                fields: schema_fields,
266                lookup,
267                attributes: ::std::collections::BTreeMap::new(),
268            })
269        }
270    };
271    let record_fields = quote! {
272        let mut schema_fields = ::std::vec::Vec::with_capacity(#minimum_fields);
273        #(#record_field_exprs)*
274        ::std::option::Option::Some(schema_fields)
275    };
276
277    Ok((schema_def, record_fields))
278}
279
280/// Use the schema definition of the only field in the struct as the schema
281fn get_transparent_struct_schema_def(
282    fields: Fields,
283    input_span: Span,
284) -> Result<(TokenStream, TokenStream), Vec<syn::Error>> {
285    match fields {
286        Fields::Named(fields_named) => {
287            let mut found = None;
288            for field in fields_named.named {
289                let attrs = FieldOptions::new(&field.attrs, field.span())?;
290                if attrs.skip {
291                    continue;
292                }
293                if found.replace((field, attrs)).is_some() {
294                    return Err(vec![syn::Error::new(
295                        input_span,
296                        "AvroSchema: #[serde(transparent)] is only allowed on structs with one unskipped field",
297                    )]);
298                }
299            }
300
301            if let Some((field, attrs)) = found {
302                Ok((
303                    get_field_schema_expr(&field, attrs.with.clone())?,
304                    get_field_get_record_fields_expr(&field, attrs.with)?,
305                ))
306            } else {
307                Err(vec![syn::Error::new(
308                    input_span,
309                    "AvroSchema: #[serde(transparent)] is only allowed on structs with one unskipped field",
310                )])
311            }
312        }
313        Fields::Unnamed(_) => Err(vec![syn::Error::new(
314            input_span,
315            "AvroSchema: derive does not work for tuple structs",
316        )]),
317        Fields::Unit => Err(vec![syn::Error::new(
318            input_span,
319            "AvroSchema: derive does not work for unit structs",
320        )]),
321    }
322}
323
324fn get_field_schema_expr(field: &Field, with: With) -> Result<TokenStream, Vec<syn::Error>> {
325    match with {
326        With::Trait => Ok(type_to_schema_expr(&field.ty)?),
327        With::Serde(path) => {
328            Ok(quote! { #path::get_schema_in_ctxt(named_schemas, enclosing_namespace) })
329        }
330        With::Expr(Expr::Closure(closure)) => {
331            if closure.inputs.is_empty() {
332                Ok(quote! { (#closure)() })
333            } else {
334                Err(vec![syn::Error::new(
335                    field.span(),
336                    "Expected closure with 0 parameters",
337                )])
338            }
339        }
340        With::Expr(Expr::Path(path)) => Ok(quote! { #path(named_schemas, enclosing_namespace) }),
341        With::Expr(_expr) => Err(vec![syn::Error::new(
342            field.span(),
343            "Invalid expression, expected function or closure",
344        )]),
345    }
346}
347
348fn get_field_get_record_fields_expr(
349    field: &Field,
350    with: With,
351) -> Result<TokenStream, Vec<syn::Error>> {
352    match with {
353        With::Trait => Ok(type_to_get_record_fields_expr(&field.ty)?),
354        With::Serde(path) => {
355            Ok(quote! { #path::get_record_fields_in_ctxt(named_schemas, enclosing_namespace) })
356        }
357        With::Expr(Expr::Closure(closure)) => {
358            if closure.inputs.is_empty() {
359                Ok(quote! {
360                    ::apache_avro::serde::get_record_fields_in_ctxt(
361                        named_schemas,
362                        enclosing_namespace,
363                        |_, _| (#closure)(),
364                    )
365                })
366            } else {
367                Err(vec![syn::Error::new(
368                    field.span(),
369                    "Expected closure with 0 parameters",
370                )])
371            }
372        }
373        With::Expr(Expr::Path(path)) => Ok(quote! {
374            ::apache_avro::serde::get_record_fields_in_ctxt(named_schemas, enclosing_namespace, #path)
375        }),
376        With::Expr(_expr) => Err(vec![syn::Error::new(
377            field.span(),
378            "Invalid expression, expected function or closure",
379        )]),
380    }
381}
382
383/// Takes in the Tokens of a type and returns the tokens of an expression with return type `Schema`
384fn type_to_schema_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
385    match ty {
386        Type::Array(_) | Type::Slice(_) | Type::Path(_) | Type::Reference(_) => Ok(
387            quote! {<#ty as :: apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
388        ),
389        Type::Ptr(_) => Err(vec![syn::Error::new_spanned(
390            ty,
391            "AvroSchema: derive does not support raw pointers",
392        )]),
393        Type::Tuple(_) => Err(vec![syn::Error::new_spanned(
394            ty,
395            "AvroSchema: derive does not support tuples",
396        )]),
397        _ => Err(vec![syn::Error::new_spanned(
398            ty,
399            format!(
400                "AvroSchema: Unexpected type encountered! Please open an issue if this kind of type should be supported: {ty:?}"
401            ),
402        )]),
403    }
404}
405
406fn type_to_get_record_fields_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
407    match ty {
408        Type::Array(_) | Type::Slice(_) | Type::Path(_) | Type::Reference(_) => Ok(
409            quote! {<#ty as :: apache_avro::AvroSchemaComponent>::get_record_fields_in_ctxt(named_schemas, enclosing_namespace)},
410        ),
411        Type::Ptr(_) => Err(vec![syn::Error::new_spanned(
412            ty,
413            "AvroSchema: derive does not support raw pointers",
414        )]),
415        Type::Tuple(_) => Err(vec![syn::Error::new_spanned(
416            ty,
417            "AvroSchema: derive does not support tuples",
418        )]),
419        _ => Err(vec![syn::Error::new_spanned(
420            ty,
421            format!(
422                "AvroSchema: Unexpected type encountered! Please open an issue if this kind of type should be supported: {ty:?}"
423            ),
424        )]),
425    }
426}
427
428fn type_to_field_default_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
429    match ty {
430        Type::Array(_) | Type::Slice(_) | Type::Path(_) | Type::Reference(_) => {
431            Ok(quote! {<#ty as :: apache_avro::AvroSchemaComponent>::field_default()})
432        }
433        Type::Ptr(_) => Err(vec![syn::Error::new_spanned(
434            ty,
435            "AvroSchema: derive does not support raw pointers",
436        )]),
437        Type::Tuple(_) => Err(vec![syn::Error::new_spanned(
438            ty,
439            "AvroSchema: derive does not support tuples",
440        )]),
441        _ => Err(vec![syn::Error::new_spanned(
442            ty,
443            format!(
444                "AvroSchema: Unexpected type encountered! Please open an issue if this kind of type should be supported: {ty:?}"
445            ),
446        )]),
447    }
448}
449
450/// Stolen from serde
451fn to_compile_errors(errors: Vec<syn::Error>) -> proc_macro2::TokenStream {
452    let compile_errors = errors.iter().map(syn::Error::to_compile_error);
453    quote!(#(#compile_errors)*)
454}
455
456fn preserve_optional(op: Option<impl quote::ToTokens>) -> TokenStream {
457    match op {
458        Some(tt) => quote! {::std::option::Option::Some(#tt.into())},
459        None => quote! {::std::option::Option::None},
460    }
461}
462
463fn aliases(op: &[impl quote::ToTokens]) -> TokenStream {
464    let items: Vec<TokenStream> = op
465        .iter()
466        .map(|tt| quote! {#tt.try_into().expect("Alias is invalid")})
467        .collect();
468    if items.is_empty() {
469        quote! {::std::option::Option::None}
470    } else {
471        quote! {::std::option::Option::Some(vec![#(#items),*])}
472    }
473}
474
475fn field_aliases(op: &[impl quote::ToTokens]) -> TokenStream {
476    let items: Vec<TokenStream> = op
477        .iter()
478        .map(|tt| quote! {#tt.try_into().expect("Alias is invalid")})
479        .collect();
480    if items.is_empty() {
481        quote! {::std::vec::Vec::new()}
482    } else {
483        quote! {vec![#(#items),*]}
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Parse `ty` as a type and render the schema expression generated for it.
    fn rendered_schema_expr(ty: TokenStream) -> String {
        let parsed: Type = syn::parse2(ty).unwrap();
        type_to_schema_expr(&parsed).unwrap().to_string()
    }

    #[test]
    fn test_trait_cast() {
        let cases = [
            (
                quote! {i32},
                quote! {<i32 as :: apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
            ),
            (
                quote! {Vec<T>},
                quote! {<Vec<T> as :: apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
            ),
            (
                quote! {AnyType},
                quote! {<AnyType as :: apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
            ),
        ];
        for (input, expected) in cases {
            assert_eq!(rendered_schema_expr(input), expected.to_string());
        }
    }
}