1#![cfg_attr(nightly, feature(proc_macro_diagnostic))]
19
20mod attributes;
33mod case;
34mod enums;
35
36use proc_macro2::{Span, TokenStream};
37use quote::quote;
38use syn::{
39 DataStruct, DeriveInput, Expr, Field, Fields, Generics, Ident, Type, parse_macro_input,
40 spanned::Spanned,
41};
42
43use crate::enums::get_data_enum_schema_def;
44use crate::{
45 attributes::{FieldDefault, FieldOptions, NamedTypeOptions, With},
46 case::RenameRule,
47};
48
49#[proc_macro_derive(AvroSchema, attributes(avro, serde))]
50pub fn proc_macro_derive_avro_schema(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
52 let input = parse_macro_input!(input as DeriveInput);
53 derive_avro_schema(input)
54 .unwrap_or_else(to_compile_errors)
55 .into()
56}
57
58fn derive_avro_schema(input: DeriveInput) -> Result<TokenStream, Vec<syn::Error>> {
59 let input_span = input.span();
62 match input.data {
63 syn::Data::Struct(data_struct) => {
64 let named_type_options = NamedTypeOptions::new(&input.ident, &input.attrs, input_span)?;
65 let (get_schema_impl, get_record_fields_impl) = if named_type_options.transparent {
66 get_transparent_struct_schema_def(data_struct.fields, input_span)?
67 } else {
68 let (schema_def, record_fields) =
69 get_struct_schema_def(&named_type_options, data_struct, input.ident.span())?;
70 (
71 handle_named_schemas(named_type_options.name, schema_def),
72 record_fields,
73 )
74 };
75 Ok(create_trait_definition(
76 input.ident,
77 &input.generics,
78 get_schema_impl,
79 get_record_fields_impl,
80 named_type_options.default,
81 ))
82 }
83 syn::Data::Enum(data_enum) => {
84 let named_type_options = NamedTypeOptions::new(&input.ident, &input.attrs, input_span)?;
85 if named_type_options.transparent {
86 return Err(vec![syn::Error::new(
87 input_span,
88 "AvroSchema: `#[serde(transparent)]` is only supported on structs",
89 )]);
90 }
91 let schema_def =
92 get_data_enum_schema_def(&named_type_options, data_enum, input.ident.span())?;
93 let inner = handle_named_schemas(named_type_options.name, schema_def);
94 Ok(create_trait_definition(
95 input.ident,
96 &input.generics,
97 inner,
98 quote! { ::std::option::Option::None },
99 named_type_options.default,
100 ))
101 }
102 syn::Data::Union(_) => Err(vec![syn::Error::new(
103 input_span,
104 "AvroSchema: derive only works for structs and simple enums",
105 )]),
106 }
107}
108
109fn create_trait_definition(
111 ident: Ident,
112 generics: &Generics,
113 get_schema_impl: TokenStream,
114 get_record_fields_impl: TokenStream,
115 field_default_impl: TokenStream,
116) -> TokenStream {
117 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
118 quote! {
119 #[automatically_derived]
120 impl #impl_generics ::apache_avro::AvroSchemaComponent for #ident #ty_generics #where_clause {
121 fn get_schema_in_ctxt(named_schemas: &mut ::std::collections::HashSet<::apache_avro::schema::Name>, enclosing_namespace: ::apache_avro::schema::NamespaceRef) -> ::apache_avro::schema::Schema {
122 #get_schema_impl
123 }
124
125 fn get_record_fields_in_ctxt(named_schemas: &mut ::std::collections::HashSet<::apache_avro::schema::Name>, enclosing_namespace: ::apache_avro::schema::NamespaceRef) -> ::std::option::Option<::std::vec::Vec<::apache_avro::schema::RecordField>> {
126 #get_record_fields_impl
127 }
128
129 fn field_default() -> ::std::option::Option<::serde_json::Value> {
130 ::std::option::Option::#field_default_impl
131 }
132 }
133 }
134}
135
136fn handle_named_schemas(full_schema_name: String, schema_def: TokenStream) -> TokenStream {
138 quote! {
139 let name = ::apache_avro::schema::Name::new_with_enclosing_namespace(#full_schema_name, enclosing_namespace).expect(concat!("Unable to parse schema name ", #full_schema_name));
140 if named_schemas.contains(&name) {
141 ::apache_avro::schema::Schema::Ref{name}
142 } else {
143 let enclosing_namespace = name.namespace();
144 named_schemas.insert(name.clone());
145 #schema_def
146 }
147 }
148}
149
150fn get_struct_schema_def(
152 container_attrs: &NamedTypeOptions,
153 data_struct: DataStruct,
154 ident_span: Span,
155) -> Result<(TokenStream, TokenStream), Vec<syn::Error>> {
156 let mut record_field_exprs = vec![];
157 match data_struct.fields {
158 Fields::Named(a) => {
159 for field in a.named {
160 let mut name = field
161 .ident
162 .as_ref()
163 .expect("Field must have a name")
164 .to_string();
165 if let Some(raw_name) = name.strip_prefix("r#") {
166 name = raw_name.to_string();
167 }
168 let field_attrs = FieldOptions::new(&field.attrs, field.span())?;
169 let doc = preserve_optional(field_attrs.doc);
170 match (field_attrs.rename, container_attrs.rename_all) {
171 (Some(rename), _) => {
172 name = rename;
173 }
174 (None, rename_all) if rename_all != RenameRule::None => {
175 name = rename_all.apply_to_field(&name);
176 }
177 _ => {}
178 }
179 if field_attrs.skip {
180 continue;
181 } else if field_attrs.flatten {
182 let get_record_fields =
185 get_field_get_record_fields_expr(&field, field_attrs.with)?;
186 record_field_exprs.push(quote! {
187 if let Some(flattened_fields) = #get_record_fields {
188 schema_fields.extend(flattened_fields);
189 } else {
190 panic!("{} does not have any fields to flatten to", stringify!(#field));
191 }
192 });
193
194 continue;
196 }
197 let default_value = match field_attrs.default {
198 FieldDefault::Disabled => quote! { ::std::option::Option::None },
199 FieldDefault::Trait => type_to_field_default_expr(&field.ty)?,
200 FieldDefault::Value(default_value) => {
201 let _: serde_json::Value = serde_json::from_str(&default_value[..])
202 .map_err(|e| {
203 vec![syn::Error::new(
204 field.ident.span(),
205 format!("Invalid avro default json: \n{e}"),
206 )]
207 })?;
208 quote! {
209 ::std::option::Option::Some(::serde_json::from_str(#default_value).expect("Unreachable! This parsed at compile time!"))
210 }
211 }
212 };
213 let aliases = field_aliases(&field_attrs.alias);
214 let schema_expr = get_field_schema_expr(&field, field_attrs.with)?;
215 record_field_exprs.push(quote! {
216 schema_fields.push(::apache_avro::schema::RecordField {
217 name: #name.to_string(),
218 doc: #doc,
219 default: #default_value,
220 aliases: #aliases,
221 schema: #schema_expr,
222 custom_attributes: ::std::collections::BTreeMap::new(),
223 });
224 });
225 }
226 }
227 Fields::Unnamed(_) => {
228 return Err(vec![syn::Error::new(
229 ident_span,
230 "AvroSchema derive does not work for tuple structs",
231 )]);
232 }
233 Fields::Unit => {
234 return Err(vec![syn::Error::new(
235 ident_span,
236 "AvroSchema derive does not work for unit structs",
237 )]);
238 }
239 }
240
241 let record_doc = preserve_optional(container_attrs.doc.as_ref());
242 let record_aliases = aliases(&container_attrs.aliases);
243 let full_schema_name = &container_attrs.name;
244
245 let minimum_fields = record_field_exprs.len();
248
249 let schema_def = quote! {
250 {
251 let mut schema_fields = ::std::vec::Vec::with_capacity(#minimum_fields);
252 #(#record_field_exprs)*
253 let schema_field_set: ::std::collections::HashSet<_> = schema_fields.iter().map(|rf| &rf.name).collect();
254 assert_eq!(schema_fields.len(), schema_field_set.len(), "Duplicate field names found: {schema_fields:?}");
255 let name = ::apache_avro::schema::Name::new(#full_schema_name).expect(&format!("Unable to parse struct name for schema {}", #full_schema_name)[..]);
256 let lookup: ::std::collections::BTreeMap<String, usize> = schema_fields
257 .iter()
258 .enumerate()
259 .map(|(position, field)| (field.name.to_owned(), position))
260 .collect();
261 ::apache_avro::schema::Schema::Record(::apache_avro::schema::RecordSchema {
262 name,
263 aliases: #record_aliases,
264 doc: #record_doc,
265 fields: schema_fields,
266 lookup,
267 attributes: ::std::collections::BTreeMap::new(),
268 })
269 }
270 };
271 let record_fields = quote! {
272 let mut schema_fields = ::std::vec::Vec::with_capacity(#minimum_fields);
273 #(#record_field_exprs)*
274 ::std::option::Option::Some(schema_fields)
275 };
276
277 Ok((schema_def, record_fields))
278}
279
280fn get_transparent_struct_schema_def(
282 fields: Fields,
283 input_span: Span,
284) -> Result<(TokenStream, TokenStream), Vec<syn::Error>> {
285 match fields {
286 Fields::Named(fields_named) => {
287 let mut found = None;
288 for field in fields_named.named {
289 let attrs = FieldOptions::new(&field.attrs, field.span())?;
290 if attrs.skip {
291 continue;
292 }
293 if found.replace((field, attrs)).is_some() {
294 return Err(vec![syn::Error::new(
295 input_span,
296 "AvroSchema: #[serde(transparent)] is only allowed on structs with one unskipped field",
297 )]);
298 }
299 }
300
301 if let Some((field, attrs)) = found {
302 Ok((
303 get_field_schema_expr(&field, attrs.with.clone())?,
304 get_field_get_record_fields_expr(&field, attrs.with)?,
305 ))
306 } else {
307 Err(vec![syn::Error::new(
308 input_span,
309 "AvroSchema: #[serde(transparent)] is only allowed on structs with one unskipped field",
310 )])
311 }
312 }
313 Fields::Unnamed(_) => Err(vec![syn::Error::new(
314 input_span,
315 "AvroSchema: derive does not work for tuple structs",
316 )]),
317 Fields::Unit => Err(vec![syn::Error::new(
318 input_span,
319 "AvroSchema: derive does not work for unit structs",
320 )]),
321 }
322}
323
324fn get_field_schema_expr(field: &Field, with: With) -> Result<TokenStream, Vec<syn::Error>> {
325 match with {
326 With::Trait => Ok(type_to_schema_expr(&field.ty)?),
327 With::Serde(path) => {
328 Ok(quote! { #path::get_schema_in_ctxt(named_schemas, enclosing_namespace) })
329 }
330 With::Expr(Expr::Closure(closure)) => {
331 if closure.inputs.is_empty() {
332 Ok(quote! { (#closure)() })
333 } else {
334 Err(vec![syn::Error::new(
335 field.span(),
336 "Expected closure with 0 parameters",
337 )])
338 }
339 }
340 With::Expr(Expr::Path(path)) => Ok(quote! { #path(named_schemas, enclosing_namespace) }),
341 With::Expr(_expr) => Err(vec![syn::Error::new(
342 field.span(),
343 "Invalid expression, expected function or closure",
344 )]),
345 }
346}
347
348fn get_field_get_record_fields_expr(
349 field: &Field,
350 with: With,
351) -> Result<TokenStream, Vec<syn::Error>> {
352 match with {
353 With::Trait => Ok(type_to_get_record_fields_expr(&field.ty)?),
354 With::Serde(path) => {
355 Ok(quote! { #path::get_record_fields_in_ctxt(named_schemas, enclosing_namespace) })
356 }
357 With::Expr(Expr::Closure(closure)) => {
358 if closure.inputs.is_empty() {
359 Ok(quote! {
360 ::apache_avro::serde::get_record_fields_in_ctxt(
361 named_schemas,
362 enclosing_namespace,
363 |_, _| (#closure)(),
364 )
365 })
366 } else {
367 Err(vec![syn::Error::new(
368 field.span(),
369 "Expected closure with 0 parameters",
370 )])
371 }
372 }
373 With::Expr(Expr::Path(path)) => Ok(quote! {
374 ::apache_avro::serde::get_record_fields_in_ctxt(named_schemas, enclosing_namespace, #path)
375 }),
376 With::Expr(_expr) => Err(vec![syn::Error::new(
377 field.span(),
378 "Invalid expression, expected function or closure",
379 )]),
380 }
381}
382
383fn type_to_schema_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
385 match ty {
386 Type::Array(_) | Type::Slice(_) | Type::Path(_) | Type::Reference(_) => Ok(
387 quote! {<#ty as :: apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
388 ),
389 Type::Ptr(_) => Err(vec![syn::Error::new_spanned(
390 ty,
391 "AvroSchema: derive does not support raw pointers",
392 )]),
393 Type::Tuple(_) => Err(vec![syn::Error::new_spanned(
394 ty,
395 "AvroSchema: derive does not support tuples",
396 )]),
397 _ => Err(vec![syn::Error::new_spanned(
398 ty,
399 format!(
400 "AvroSchema: Unexpected type encountered! Please open an issue if this kind of type should be supported: {ty:?}"
401 ),
402 )]),
403 }
404}
405
406fn type_to_get_record_fields_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
407 match ty {
408 Type::Array(_) | Type::Slice(_) | Type::Path(_) | Type::Reference(_) => Ok(
409 quote! {<#ty as :: apache_avro::AvroSchemaComponent>::get_record_fields_in_ctxt(named_schemas, enclosing_namespace)},
410 ),
411 Type::Ptr(_) => Err(vec![syn::Error::new_spanned(
412 ty,
413 "AvroSchema: derive does not support raw pointers",
414 )]),
415 Type::Tuple(_) => Err(vec![syn::Error::new_spanned(
416 ty,
417 "AvroSchema: derive does not support tuples",
418 )]),
419 _ => Err(vec![syn::Error::new_spanned(
420 ty,
421 format!(
422 "AvroSchema: Unexpected type encountered! Please open an issue if this kind of type should be supported: {ty:?}"
423 ),
424 )]),
425 }
426}
427
428fn type_to_field_default_expr(ty: &Type) -> Result<TokenStream, Vec<syn::Error>> {
429 match ty {
430 Type::Array(_) | Type::Slice(_) | Type::Path(_) | Type::Reference(_) => {
431 Ok(quote! {<#ty as :: apache_avro::AvroSchemaComponent>::field_default()})
432 }
433 Type::Ptr(_) => Err(vec![syn::Error::new_spanned(
434 ty,
435 "AvroSchema: derive does not support raw pointers",
436 )]),
437 Type::Tuple(_) => Err(vec![syn::Error::new_spanned(
438 ty,
439 "AvroSchema: derive does not support tuples",
440 )]),
441 _ => Err(vec![syn::Error::new_spanned(
442 ty,
443 format!(
444 "AvroSchema: Unexpected type encountered! Please open an issue if this kind of type should be supported: {ty:?}"
445 ),
446 )]),
447 }
448}
449
450fn to_compile_errors(errors: Vec<syn::Error>) -> proc_macro2::TokenStream {
452 let compile_errors = errors.iter().map(syn::Error::to_compile_error);
453 quote!(#(#compile_errors)*)
454}
455
456fn preserve_optional(op: Option<impl quote::ToTokens>) -> TokenStream {
457 match op {
458 Some(tt) => quote! {::std::option::Option::Some(#tt.into())},
459 None => quote! {::std::option::Option::None},
460 }
461}
462
463fn aliases(op: &[impl quote::ToTokens]) -> TokenStream {
464 let items: Vec<TokenStream> = op
465 .iter()
466 .map(|tt| quote! {#tt.try_into().expect("Alias is invalid")})
467 .collect();
468 if items.is_empty() {
469 quote! {::std::option::Option::None}
470 } else {
471 quote! {::std::option::Option::Some(vec![#(#items),*])}
472 }
473}
474
475fn field_aliases(op: &[impl quote::ToTokens]) -> TokenStream {
476 let items: Vec<TokenStream> = op
477 .iter()
478 .map(|tt| quote! {#tt.try_into().expect("Alias is invalid")})
479 .collect();
480 if items.is_empty() {
481 quote! {::std::vec::Vec::new()}
482 } else {
483 quote! {vec![#(#items),*]}
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// Parses `ty` as a type, expands it through `type_to_schema_expr`, and
    /// compares the generated tokens with `expected` by their string form.
    fn assert_schema_expr(ty: TokenStream, expected: TokenStream) {
        let parsed = syn::parse2::<Type>(ty).unwrap();
        let generated = type_to_schema_expr(&parsed).unwrap();
        assert_eq!(generated.to_string(), expected.to_string());
    }

    #[test]
    fn test_trait_cast() {
        assert_schema_expr(
            quote! {i32},
            quote! {<i32 as :: apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
        );
        assert_schema_expr(
            quote! {Vec<T>},
            quote! {<Vec<T> as :: apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
        );
        assert_schema_expr(
            quote! {AnyType},
            quote! {<AnyType as :: apache_avro::AvroSchemaComponent>::get_schema_in_ctxt(named_schemas, enclosing_namespace)},
        );
    }
}