1use darling::FromAttributes;
19use proc_macro2::{Span, TokenStream};
20use quote::quote;
21
22use syn::{
23 parse_macro_input, spanned::Spanned, AttrStyle, Attribute, DeriveInput, Ident, Meta, Type,
24 TypePath,
25};
26
/// Options accepted in `#[avro(...)]` attributes placed on a struct field.
#[derive(darling::FromAttributes)]
#[darling(attributes(avro))]
struct FieldOptions {
    // `#[avro(doc = "...")]`: field documentation; when absent the field's
    // outer `///` comments are used instead.
    #[darling(default)]
    doc: Option<String>,
    // `#[avro(default = "...")]`: default value written as JSON text; it is
    // validated as JSON when the derive expands.
    #[darling(default)]
    default: Option<String>,
    // `#[avro(alias = "...")]`, repeatable: each occurrence adds one alias.
    #[darling(multiple)]
    alias: Vec<String>,
    // `#[avro(rename = "...")]`: replaces the field name used in the schema.
    #[darling(default)]
    rename: Option<String>,
    // `#[avro(skip)]`: when `Some(true)` the field is left out of the record.
    #[darling(default)]
    skip: Option<bool>,
}
41
/// Options accepted in `#[avro(...)]` attributes placed on the derived
/// struct or enum itself.
#[derive(darling::FromAttributes)]
#[darling(attributes(avro))]
struct NamedTypeOptions {
    // `#[avro(namespace = "...")]`: prefixed (with a `.`) to the type name to
    // form the full schema name.
    #[darling(default)]
    namespace: Option<String>,
    // `#[avro(doc = "...")]`: schema documentation; when absent the type's
    // outer `///` comments are used instead.
    #[darling(default)]
    doc: Option<String>,
    // `#[avro(alias = "...")]`, repeatable: each occurrence adds one alias.
    #[darling(multiple)]
    alias: Vec<String>,
}
52
53#[proc_macro_derive(AvroSchema, attributes(avro))]
54pub fn proc_macro_derive_avro_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
56 let mut input = parse_macro_input!(input as DeriveInput);
57 derive_avro_schema(&mut input)
58 .unwrap_or_else(to_compile_errors)
59 .into()
60}
61
62fn derive_avro_schema(input: &mut DeriveInput) -> Result<TokenStream, Vec<syn::Error>> {
63 let named_type_options =
64 NamedTypeOptions::from_attributes(&input.attrs[..]).map_err(darling_to_syn)?;
65 let full_schema_name = vec![named_type_options.namespace, Some(input.ident.to_string())]
66 .into_iter()
67 .flatten()
68 .collect::<Vec<String>>()
69 .join(".");
70 let schema_def = match &input.data {
71 syn::Data::Struct(s) => get_data_struct_schema_def(
72 &full_schema_name,
73 named_type_options
74 .doc
75 .or_else(|| extract_outer_doc(&input.attrs)),
76 named_type_options.alias,
77 s,
78 input.ident.span(),
79 )?,
80 syn::Data::Enum(e) => get_data_enum_schema_def(
81 &full_schema_name,
82 named_type_options
83 .doc
84 .or_else(|| extract_outer_doc(&input.attrs)),
85 named_type_options.alias,
86 e,
87 input.ident.span(),
88 )?,
89 _ => {
90 return Err(vec![syn::Error::new(
91 input.ident.span(),
92 "AvroSchema derive only works for structs and simple enums ",
93 )])
94 }
95 };
96 let ident = &input.ident;
97 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
98 Ok(quote! {
99 impl #impl_generics apache_avro::schema::derive::AvroSchemaComponent for #ident #ty_generics #where_clause {
100 fn get_schema_in_ctxt(named_schemas: &mut std::collections::HashMap<apache_avro::schema::Name, apache_avro::schema::Schema>, enclosing_namespace: &Option<String>) -> apache_avro::schema::Schema {
101 let name = apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse schema name {}", #full_schema_name)[..]).fully_qualified_name(enclosing_namespace);
102 let enclosing_namespace = &name.namespace;
103 if named_schemas.contains_key(&name) {
104 apache_avro::schema::Schema::Ref{name: name.clone()}
105 } else {
106 named_schemas.insert(name.clone(), apache_avro::schema::Schema::Ref{name: name.clone()});
107 #schema_def
108 }
109 }
110 }
111 })
112}
113
114fn get_data_struct_schema_def(
115 full_schema_name: &str,
116 record_doc: Option<String>,
117 aliases: Vec<String>,
118 s: &syn::DataStruct,
119 error_span: Span,
120) -> Result<TokenStream, Vec<syn::Error>> {
121 let mut record_field_exprs = vec![];
122 match s.fields {
123 syn::Fields::Named(ref a) => {
124 let mut index: usize = 0;
125 for field in a.named.iter() {
126 let mut name = field.ident.as_ref().unwrap().to_string(); if let Some(raw_name) = name.strip_prefix("r#") {
128 name = raw_name.to_string();
129 }
130 let field_attrs =
131 FieldOptions::from_attributes(&field.attrs[..]).map_err(darling_to_syn)?;
132 let doc =
133 preserve_optional(field_attrs.doc.or_else(|| extract_outer_doc(&field.attrs)));
134 if let Some(rename) = field_attrs.rename {
135 name = rename
136 }
137 if let Some(true) = field_attrs.skip {
138 continue;
139 }
140 let default_value = match field_attrs.default {
141 Some(default_value) => {
142 let _: serde_json::Value = serde_json::from_str(&default_value[..])
143 .map_err(|e| {
144 vec![syn::Error::new(
145 field.ident.span(),
146 format!("Invalid avro default json: \n{e}"),
147 )]
148 })?;
149 quote! {
150 Some(serde_json::from_str(#default_value).expect(format!("Invalid JSON: {:?}", #default_value).as_str()))
151 }
152 }
153 None => quote! { None },
154 };
155 let aliases = preserve_vec(field_attrs.alias);
156 let schema_expr = type_to_schema_expr(&field.ty)?;
157 let position = index;
158 record_field_exprs.push(quote! {
159 apache_avro::schema::RecordField {
160 name: #name.to_string(),
161 doc: #doc,
162 default: #default_value,
163 aliases: #aliases,
164 schema: #schema_expr,
165 order: apache_avro::schema::RecordFieldOrder::Ascending,
166 position: #position,
167 custom_attributes: Default::default(),
168 }
169 });
170 index += 1;
171 }
172 }
173 syn::Fields::Unnamed(_) => {
174 return Err(vec![syn::Error::new(
175 error_span,
176 "AvroSchema derive does not work for tuple structs",
177 )])
178 }
179 syn::Fields::Unit => {
180 return Err(vec![syn::Error::new(
181 error_span,
182 "AvroSchema derive does not work for unit structs",
183 )])
184 }
185 }
186 let record_doc = preserve_optional(record_doc);
187 let record_aliases = preserve_vec(aliases);
188 Ok(quote! {
189 let schema_fields = vec![#(#record_field_exprs),*];
190 let name = apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse struct name for schema {}", #full_schema_name)[..]);
191 let lookup: std::collections::BTreeMap<String, usize> = schema_fields
192 .iter()
193 .map(|field| (field.name.to_owned(), field.position))
194 .collect();
195 apache_avro::schema::Schema::Record(apache_avro::schema::RecordSchema {
196 name,
197 aliases: #record_aliases,
198 doc: #record_doc,
199 fields: schema_fields,
200 lookup,
201 attributes: Default::default(),
202 })
203 })
204}
205
206fn get_data_enum_schema_def(
207 full_schema_name: &str,
208 doc: Option<String>,
209 aliases: Vec<String>,
210 e: &syn::DataEnum,
211 error_span: Span,
212) -> Result<TokenStream, Vec<syn::Error>> {
213 let doc = preserve_optional(doc);
214 let enum_aliases = preserve_vec(aliases);
215 if e.variants.iter().all(|v| syn::Fields::Unit == v.fields) {
216 let default_value = default_enum_variant(e, error_span)?;
217 let default = preserve_optional(default_value);
218 let symbols: Vec<String> = e
219 .variants
220 .iter()
221 .map(|variant| variant.ident.to_string())
222 .collect();
223 Ok(quote! {
224 apache_avro::schema::Schema::Enum(apache_avro::schema::EnumSchema {
225 name: apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse enum name for schema {}", #full_schema_name)[..]),
226 aliases: #enum_aliases,
227 doc: #doc,
228 symbols: vec![#(#symbols.to_owned()),*],
229 default: #default,
230 attributes: Default::default(),
231 })
232 })
233 } else {
234 Err(vec![syn::Error::new(
235 error_span,
236 "AvroSchema derive does not work for enums with non unit structs",
237 )])
238 }
239}
240
241fn type_to_schema_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
243 if let Type::Path(p) = ty {
244 let type_string = p.path.segments.last().unwrap().ident.to_string();
245
246 let schema = match &type_string[..] {
247 "bool" => quote! {apache_avro::schema::Schema::Boolean},
248 "i8" | "i16" | "i32" | "u8" | "u16" => quote! {apache_avro::schema::Schema::Int},
249 "u32" | "i64" => quote! {apache_avro::schema::Schema::Long},
250 "f32" => quote! {apache_avro::schema::Schema::Float},
251 "f64" => quote! {apache_avro::schema::Schema::Double},
252 "String" | "str" => quote! {apache_avro::schema::Schema::String},
253 "char" => {
254 return Err(vec![syn::Error::new_spanned(
255 ty,
256 "AvroSchema: Cannot guarantee successful deserialization of this type",
257 )])
258 }
259 "u64" => {
260 return Err(vec![syn::Error::new_spanned(
261 ty,
262 "Cannot guarantee successful serialization of this type due to overflow concerns",
263 )])
264 } _ => {
266 type_path_schema_expr(p)
269 }
270 };
271 Ok(schema)
272 } else if let Type::Array(ta) = ty {
273 let inner_schema_expr = type_to_schema_expr(&ta.elem)?;
274 Ok(quote! {apache_avro::schema::Schema::array(#inner_schema_expr)})
275 } else if let Type::Reference(tr) = ty {
276 type_to_schema_expr(&tr.elem)
277 } else {
278 Err(vec![syn::Error::new_spanned(
279 ty,
280 format!("Unable to generate schema for type: {ty:?}"),
281 )])
282 }
283}
284
285fn default_enum_variant(
286 data_enum: &syn::DataEnum,
287 error_span: Span,
288) -> Result<Option<String>, Vec<syn::Error>> {
289 match data_enum
290 .variants
291 .iter()
292 .filter(|v| v.attrs.iter().any(is_default_attr))
293 .collect::<Vec<_>>()
294 {
295 variants if variants.is_empty() => Ok(None),
296 single if single.len() == 1 => Ok(Some(single[0].ident.to_string())),
297 multiple => Err(vec![syn::Error::new(
298 error_span,
299 format!(
300 "Multiple defaults defined: {:?}",
301 multiple
302 .iter()
303 .map(|v| v.ident.to_string())
304 .collect::<Vec<String>>()
305 ),
306 )]),
307 }
308}
309
310fn is_default_attr(attr: &Attribute) -> bool {
311 matches!(attr, Attribute { meta: Meta::Path(path), .. } if path.get_ident().map(Ident::to_string).as_deref() == Some("default"))
312}
313
314fn type_path_schema_expr(p: &TypePath) -> TokenStream {
318 quote! {<#p as apache_avro::schema::derive::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)}
319}
320
321fn to_compile_errors(errors: Vec<syn::Error>) -> proc_macro2::TokenStream {
323 let compile_errors = errors.iter().map(syn::Error::to_compile_error);
324 quote!(#(#compile_errors)*)
325}
326
327fn extract_outer_doc(attributes: &[Attribute]) -> Option<String> {
328 let doc = attributes
329 .iter()
330 .filter(|attr| attr.style == AttrStyle::Outer && attr.path().is_ident("doc"))
331 .filter_map(|attr| {
332 let name_value = attr.meta.require_name_value();
333 match name_value {
334 Ok(name_value) => match &name_value.value {
335 syn::Expr::Lit(expr_lit) => match expr_lit.lit {
336 syn::Lit::Str(ref lit_str) => Some(lit_str.value().trim().to_string()),
337 _ => None,
338 },
339 _ => None,
340 },
341 Err(_) => None,
342 }
343 })
344 .collect::<Vec<String>>()
345 .join("\n");
346 if doc.is_empty() {
347 None
348 } else {
349 Some(doc)
350 }
351}
352
353fn preserve_optional(op: Option<impl quote::ToTokens>) -> TokenStream {
354 match op {
355 Some(tt) => quote! {Some(#tt.into())},
356 None => quote! {None},
357 }
358}
359
360fn preserve_vec(op: Vec<impl quote::ToTokens>) -> TokenStream {
361 let items: Vec<TokenStream> = op.iter().map(|tt| quote! {#tt.into()}).collect();
362 if items.is_empty() {
363 quote! {None}
364 } else {
365 quote! {Some(vec![#(#items),*])}
366 }
367}
368
369fn darling_to_syn(e: darling::Error) -> Vec<syn::Error> {
370 let msg = format!("{e}");
371 let token_errors = e.write_errors();
372 vec![syn::Error::new(token_errors.span(), msg)]
373}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `tokens` as a derive input. Panics with a descriptive message
    /// if the snippet is not a valid item.
    fn parse_input(tokens: TokenStream) -> DeriveInput {
        match syn::parse2::<DeriveInput>(tokens) {
            Ok(input) => input,
            Err(error) => panic!(
                "Failed to parse as derive input when it should be able to. Error: {error:?}"
            ),
        }
    }

    /// Parses `tokens` and runs the `AvroSchema` derive on the result.
    fn derive(tokens: TokenStream) -> Result<TokenStream, Vec<syn::Error>> {
        let mut input = parse_input(tokens);
        derive_avro_schema(&mut input)
    }

    #[test]
    fn basic_case() {
        let derived = derive(quote! {
            struct A {
                a: i32,
                b: String
            }
        });
        assert!(derived.is_ok());
    }

    #[test]
    fn tuple_struct_unsupported() {
        let derived = derive(quote! {
            struct B (i32, String);
        });
        assert!(derived.is_err());
    }

    #[test]
    fn unit_struct_unsupported() {
        let derived = derive(quote! {
            struct AbsoluteUnit;
        });
        assert!(derived.is_err());
    }

    #[test]
    fn struct_with_optional() {
        let derived = derive(quote! {
            struct Test4 {
                a : Option<i32>
            }
        });
        assert!(derived.is_ok());
    }

    #[test]
    fn test_basic_enum() {
        let derived = derive(quote! {
            enum Basic {
                A,
                B,
                C,
                D
            }
        });
        assert!(derived.is_ok());
    }

    #[test]
    fn avro_3687_basic_enum_with_default() {
        let derived = derive(quote! {
            enum Basic {
                #[default]
                A,
                B,
                C,
                D
            }
        });
        assert!(derived.is_ok());
        let expected = quote! {
            impl apache_avro::schema::derive::AvroSchemaComponent for Basic {
                fn get_schema_in_ctxt(
                    named_schemas: &mut std::collections::HashMap<
                        apache_avro::schema::Name,
                        apache_avro::schema::Schema
                    >,
                    enclosing_namespace: &Option<String>
                ) -> apache_avro::schema::Schema {
                    let name = apache_avro::schema::Name::new("Basic")
                        .expect(&format!("Unable to parse schema name {}", "Basic")[..])
                        .fully_qualified_name(enclosing_namespace);
                    let enclosing_namespace = &name.namespace;
                    if named_schemas.contains_key(&name) {
                        apache_avro::schema::Schema::Ref { name: name.clone() }
                    } else {
                        named_schemas.insert(
                            name.clone(),
                            apache_avro::schema::Schema::Ref { name: name.clone() }
                        );
                        apache_avro::schema::Schema::Enum(apache_avro::schema::EnumSchema {
                            name: apache_avro::schema::Name::new("Basic").expect(
                                &format!("Unable to parse enum name for schema {}", "Basic")[..]
                            ),
                            aliases: None,
                            doc: None,
                            symbols: vec![
                                "A".to_owned(),
                                "B".to_owned(),
                                "C".to_owned(),
                                "D".to_owned()
                            ],
                            default: Some("A".into()),
                            attributes: Default::default(),
                        })
                    }
                }
            }
        }
        .to_string();
        assert_eq!(derived.unwrap().to_string(), expected);
    }

    #[test]
    fn avro_3687_basic_enum_with_default_twice() {
        let outcome = derive(quote! {
            enum Basic {
                #[default]
                A,
                B,
                #[default]
                C,
                D
            }
        });
        let errors = match outcome {
            Ok(_) => {
                panic!("Should not be able to derive schema for enum with multiple defaults")
            }
            Err(errors) => errors,
        };
        assert_eq!(errors.len(), 1);
        assert_eq!(
            errors[0].to_string(),
            r#"Multiple defaults defined: ["A", "C"]"#
        );
    }

    #[test]
    fn test_non_basic_enum() {
        let derived = derive(quote! {
            enum Basic {
                A(i32),
                B,
                C,
                D
            }
        });
        assert!(derived.is_err());
    }

    #[test]
    fn test_namespace() {
        let derived = derive(quote! {
            #[avro(namespace = "namespace.testing")]
            struct A {
                a: i32,
                b: String
            }
        });
        assert!(derived.is_ok());
        let rendered = derived.unwrap().to_string();
        assert!(rendered.contains("namespace.testing"));
    }

    #[test]
    fn test_reference() {
        let derived = derive(quote! {
            struct A<'a> {
                a: &'a Vec<i32>,
                b: &'static str
            }
        });
        assert!(derived.is_ok());
    }

    #[test]
    fn test_trait_cast() {
        let cast_for = |ty: TokenStream| {
            type_path_schema_expr(&syn::parse2::<TypePath>(ty).unwrap()).to_string()
        };
        assert_eq!(
            cast_for(quote! {i32}),
            quote! {<i32 as apache_avro::schema::derive::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)}.to_string()
        );
        assert_eq!(
            cast_for(quote! {Vec<T>}),
            quote! {<Vec<T> as apache_avro::schema::derive::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)}.to_string()
        );
        assert_eq!(
            cast_for(quote! {AnyType}),
            quote! {<AnyType as apache_avro::schema::derive::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)}.to_string()
        );
    }

    #[test]
    fn test_avro_3709_record_field_attributes() {
        let derived = derive(quote! {
            struct A {
                #[avro(alias = "a1", alias = "a2", doc = "a doc", default = "123", rename = "a3")]
                a: i32
            }
        });
        let expected_token_stream = r#"let schema_fields = vec ! [apache_avro :: schema :: RecordField { name : "a3" . to_string () , doc : Some ("a doc" . into ()) , default : Some (serde_json :: from_str ("123") . expect (format ! ("Invalid JSON: {:?}" , "123") . as_str ())) , aliases : Some (vec ! ["a1" . into () , "a2" . into ()]) , schema : apache_avro :: schema :: Schema :: Int , order : apache_avro :: schema :: RecordFieldOrder :: Ascending , position : 0usize , custom_attributes : Default :: default () , }] ;"#;
        let schema_token_stream = derived.unwrap().to_string();
        assert!(schema_token_stream.contains(expected_token_stream));
    }
}