thiserror_impl/
expand.rs

1use crate::ast::{Enum, Field, Input, Struct};
2use crate::attr::Trait;
3use crate::fallback;
4use crate::generics::InferredBounds;
5use crate::private;
6use crate::unraw::MemberUnraw;
7use proc_macro2::{Ident, Span, TokenStream};
8use quote::{format_ident, quote, quote_spanned, ToTokens};
9use std::collections::BTreeSet as Set;
10use syn::{DeriveInput, GenericArgument, PathArguments, Result, Token, Type};
11
12pub fn derive(input: &DeriveInput) -> TokenStream {
13    match try_expand(input) {
14        Ok(expanded) => expanded,
15        // If there are invalid attributes in the input, expand to an Error impl
16        // anyway to minimize spurious secondary errors in other code that uses
17        // this type as an Error.
18        Err(error) => fallback::expand(input, error),
19    }
20}
21
22fn try_expand(input: &DeriveInput) -> Result<TokenStream> {
23    let input = Input::from_syn(input)?;
24    input.validate()?;
25    Ok(match input {
26        Input::Struct(input) => impl_struct(input),
27        Input::Enum(input) => impl_enum(input),
28    })
29}
30
31fn impl_struct(input: Struct) -> TokenStream {
32    let ty = call_site_ident(&input.ident);
33    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
34    let mut error_inferred_bounds = InferredBounds::new();
35
36    let source_body = if let Some(transparent_attr) = &input.attrs.transparent {
37        let only_field = &input.fields[0];
38        if only_field.contains_generic {
39            error_inferred_bounds.insert(only_field.ty, quote!(::thiserror::#private::Error));
40        }
41        let member = &only_field.member;
42        Some(quote_spanned! {transparent_attr.span=>
43            ::thiserror::#private::Error::source(self.#member.as_dyn_error())
44        })
45    } else if let Some(source_field) = input.source_field() {
46        let source = &source_field.member;
47        if source_field.contains_generic {
48            let ty = unoptional_type(source_field.ty);
49            error_inferred_bounds.insert(ty, quote!(::thiserror::#private::Error + 'static));
50        }
51        let asref = if type_is_option(source_field.ty) {
52            Some(quote_spanned!(source.span()=> .as_ref()?))
53        } else {
54            None
55        };
56        let dyn_error = quote_spanned! {source_field.source_span()=>
57            self.#source #asref.as_dyn_error()
58        };
59        Some(quote! {
60            ::core::option::Option::Some(#dyn_error)
61        })
62    } else {
63        None
64    };
65    let source_method = source_body.map(|body| {
66        quote! {
67            fn source(&self) -> ::core::option::Option<&(dyn ::thiserror::#private::Error + 'static)> {
68                use ::thiserror::#private::AsDynError as _;
69                #body
70            }
71        }
72    });
73
74    let provide_method = input.backtrace_field().map(|backtrace_field| {
75        let request = quote!(request);
76        let backtrace = &backtrace_field.member;
77        let body = if let Some(source_field) = input.source_field() {
78            let source = &source_field.member;
79            let source_provide = if type_is_option(source_field.ty) {
80                quote_spanned! {source.span()=>
81                    if let ::core::option::Option::Some(source) = &self.#source {
82                        source.thiserror_provide(#request);
83                    }
84                }
85            } else {
86                quote_spanned! {source.span()=>
87                    self.#source.thiserror_provide(#request);
88                }
89            };
90            let self_provide = if source == backtrace {
91                None
92            } else if type_is_option(backtrace_field.ty) {
93                Some(quote! {
94                    if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
95                        #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
96                    }
97                })
98            } else {
99                Some(quote! {
100                    #request.provide_ref::<::thiserror::#private::Backtrace>(&self.#backtrace);
101                })
102            };
103            quote! {
104                use ::thiserror::#private::ThiserrorProvide as _;
105                #source_provide
106                #self_provide
107            }
108        } else if type_is_option(backtrace_field.ty) {
109            quote! {
110                if let ::core::option::Option::Some(backtrace) = &self.#backtrace {
111                    #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
112                }
113            }
114        } else {
115            quote! {
116                #request.provide_ref::<::thiserror::#private::Backtrace>(&self.#backtrace);
117            }
118        };
119        quote! {
120            fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
121                #body
122            }
123        }
124    });
125
126    let mut display_implied_bounds = Set::new();
127    let display_body = if input.attrs.transparent.is_some() {
128        let only_field = &input.fields[0].member;
129        display_implied_bounds.insert((0, Trait::Display));
130        Some(quote! {
131            ::core::fmt::Display::fmt(&self.#only_field, __formatter)
132        })
133    } else if let Some(display) = &input.attrs.display {
134        display_implied_bounds.clone_from(&display.implied_bounds);
135        let use_as_display = use_as_display(display.has_bonus_display);
136        let pat = fields_pat(&input.fields);
137        Some(quote! {
138            #use_as_display
139            #[allow(unused_variables, deprecated)]
140            let Self #pat = self;
141            #display
142        })
143    } else {
144        None
145    };
146    let display_impl = display_body.map(|body| {
147        let mut display_inferred_bounds = InferredBounds::new();
148        for (field, bound) in display_implied_bounds {
149            let field = &input.fields[field];
150            if field.contains_generic {
151                display_inferred_bounds.insert(field.ty, bound);
152            }
153        }
154        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
155        quote! {
156            #[allow(unused_qualifications)]
157            #[automatically_derived]
158            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
159                #[allow(clippy::used_underscore_binding)]
160                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
161                    #body
162                }
163            }
164        }
165    });
166
167    let from_impl = input.from_field().map(|from_field| {
168        let span = from_field.attrs.from.unwrap().span;
169        let backtrace_field = input.distinct_backtrace_field();
170        let from = unoptional_type(from_field.ty);
171        let source_var = Ident::new("source", span);
172        let body = from_initializer(from_field, backtrace_field, &source_var);
173        let from_function = quote! {
174            fn from(#source_var: #from) -> Self {
175                #ty #body
176            }
177        };
178        let from_impl = quote_spanned! {span=>
179            #[automatically_derived]
180            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
181                #from_function
182            }
183        };
184        let lint_allows = if input.generics.lifetimes().next().is_some() {
185            Some(quote! {
186                clippy::elidable_lifetime_names,
187                clippy::needless_lifetimes,
188            })
189        } else {
190            None
191        };
192        Some(quote! {
193            #[allow(
194                deprecated,
195                unused_qualifications,
196                #lint_allows
197            )]
198            #from_impl
199        })
200    });
201
202    if input.generics.type_params().next().is_some() {
203        let self_token = <Token![Self]>::default();
204        error_inferred_bounds.insert(self_token, Trait::Debug);
205        error_inferred_bounds.insert(self_token, Trait::Display);
206    }
207    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
208
209    quote! {
210        #[allow(unused_qualifications)]
211        #[automatically_derived]
212        impl #impl_generics ::thiserror::#private::Error for #ty #ty_generics #error_where_clause {
213            #source_method
214            #provide_method
215        }
216        #display_impl
217        #from_impl
218    }
219}
220
221fn impl_enum(input: Enum) -> TokenStream {
222    let ty = call_site_ident(&input.ident);
223    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
224    let mut error_inferred_bounds = InferredBounds::new();
225
226    let source_method = if input.has_source() {
227        let arms = input.variants.iter().map(|variant| {
228            let ident = &variant.ident;
229            if let Some(transparent_attr) = &variant.attrs.transparent {
230                let only_field = &variant.fields[0];
231                if only_field.contains_generic {
232                    error_inferred_bounds.insert(only_field.ty, quote!(::thiserror::#private::Error));
233                }
234                let member = &only_field.member;
235                let source = quote_spanned! {transparent_attr.span=>
236                    ::thiserror::#private::Error::source(transparent.as_dyn_error())
237                };
238                quote! {
239                    #ty::#ident {#member: transparent} => #source,
240                }
241            } else if let Some(source_field) = variant.source_field() {
242                let source = &source_field.member;
243                if source_field.contains_generic {
244                    let ty = unoptional_type(source_field.ty);
245                    error_inferred_bounds.insert(ty, quote!(::thiserror::#private::Error + 'static));
246                }
247                let asref = if type_is_option(source_field.ty) {
248                    Some(quote_spanned!(source.span()=> .as_ref()?))
249                } else {
250                    None
251                };
252                let varsource = quote!(source);
253                let dyn_error = quote_spanned! {source_field.source_span()=>
254                    #varsource #asref.as_dyn_error()
255                };
256                quote! {
257                    #ty::#ident {#source: #varsource, ..} => ::core::option::Option::Some(#dyn_error),
258                }
259            } else {
260                quote! {
261                    #ty::#ident {..} => ::core::option::Option::None,
262                }
263            }
264        });
265        Some(quote! {
266            fn source(&self) -> ::core::option::Option<&(dyn ::thiserror::#private::Error + 'static)> {
267                use ::thiserror::#private::AsDynError as _;
268                #[allow(deprecated)]
269                match self {
270                    #(#arms)*
271                }
272            }
273        })
274    } else {
275        None
276    };
277
278    let provide_method = if input.has_backtrace() {
279        let request = quote!(request);
280        let arms = input.variants.iter().map(|variant| {
281            let ident = &variant.ident;
282            match (variant.backtrace_field(), variant.source_field()) {
283                (Some(backtrace_field), Some(source_field))
284                    if backtrace_field.attrs.backtrace.is_none() =>
285                {
286                    let backtrace = &backtrace_field.member;
287                    let source = &source_field.member;
288                    let varsource = quote!(source);
289                    let source_provide = if type_is_option(source_field.ty) {
290                        quote_spanned! {source.span()=>
291                            if let ::core::option::Option::Some(source) = #varsource {
292                                source.thiserror_provide(#request);
293                            }
294                        }
295                    } else {
296                        quote_spanned! {source.span()=>
297                            #varsource.thiserror_provide(#request);
298                        }
299                    };
300                    let self_provide = if type_is_option(backtrace_field.ty) {
301                        quote! {
302                            if let ::core::option::Option::Some(backtrace) = backtrace {
303                                #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
304                            }
305                        }
306                    } else {
307                        quote! {
308                            #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
309                        }
310                    };
311                    quote! {
312                        #ty::#ident {
313                            #backtrace: backtrace,
314                            #source: #varsource,
315                            ..
316                        } => {
317                            use ::thiserror::#private::ThiserrorProvide as _;
318                            #source_provide
319                            #self_provide
320                        }
321                    }
322                }
323                (Some(backtrace_field), Some(source_field))
324                    if backtrace_field.member == source_field.member =>
325                {
326                    let backtrace = &backtrace_field.member;
327                    let varsource = quote!(source);
328                    let source_provide = if type_is_option(source_field.ty) {
329                        quote_spanned! {backtrace.span()=>
330                            if let ::core::option::Option::Some(source) = #varsource {
331                                source.thiserror_provide(#request);
332                            }
333                        }
334                    } else {
335                        quote_spanned! {backtrace.span()=>
336                            #varsource.thiserror_provide(#request);
337                        }
338                    };
339                    quote! {
340                        #ty::#ident {#backtrace: #varsource, ..} => {
341                            use ::thiserror::#private::ThiserrorProvide as _;
342                            #source_provide
343                        }
344                    }
345                }
346                (Some(backtrace_field), _) => {
347                    let backtrace = &backtrace_field.member;
348                    let body = if type_is_option(backtrace_field.ty) {
349                        quote! {
350                            if let ::core::option::Option::Some(backtrace) = backtrace {
351                                #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
352                            }
353                        }
354                    } else {
355                        quote! {
356                            #request.provide_ref::<::thiserror::#private::Backtrace>(backtrace);
357                        }
358                    };
359                    quote! {
360                        #ty::#ident {#backtrace: backtrace, ..} => {
361                            #body
362                        }
363                    }
364                }
365                (None, _) => quote! {
366                    #ty::#ident {..} => {}
367                },
368            }
369        });
370        Some(quote! {
371            fn provide<'_request>(&'_request self, #request: &mut ::core::error::Request<'_request>) {
372                #[allow(deprecated)]
373                match self {
374                    #(#arms)*
375                }
376            }
377        })
378    } else {
379        None
380    };
381
382    let display_impl = if input.has_display() {
383        let mut display_inferred_bounds = InferredBounds::new();
384        let has_bonus_display = input.variants.iter().any(|v| {
385            v.attrs
386                .display
387                .as_ref()
388                .map_or(false, |display| display.has_bonus_display)
389        });
390        let use_as_display = use_as_display(has_bonus_display);
391        let void_deref = if input.variants.is_empty() {
392            Some(quote!(*))
393        } else {
394            None
395        };
396        let arms = input.variants.iter().map(|variant| {
397            let mut display_implied_bounds = Set::new();
398            let display = if let Some(display) = &variant.attrs.display {
399                display_implied_bounds.clone_from(&display.implied_bounds);
400                display.to_token_stream()
401            } else if let Some(fmt) = &variant.attrs.fmt {
402                let fmt_path = &fmt.path;
403                let vars = variant.fields.iter().map(|field| match &field.member {
404                    MemberUnraw::Named(ident) => ident.to_local(),
405                    MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
406                });
407                quote!(#fmt_path(#(#vars,)* __formatter))
408            } else {
409                let only_field = match &variant.fields[0].member {
410                    MemberUnraw::Named(ident) => ident.to_local(),
411                    MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
412                };
413                display_implied_bounds.insert((0, Trait::Display));
414                quote!(::core::fmt::Display::fmt(#only_field, __formatter))
415            };
416            for (field, bound) in display_implied_bounds {
417                let field = &variant.fields[field];
418                if field.contains_generic {
419                    display_inferred_bounds.insert(field.ty, bound);
420                }
421            }
422            let ident = &variant.ident;
423            let pat = fields_pat(&variant.fields);
424            quote! {
425                #ty::#ident #pat => #display
426            }
427        });
428        let arms = arms.collect::<Vec<_>>();
429        let display_where_clause = display_inferred_bounds.augment_where_clause(input.generics);
430        Some(quote! {
431            #[allow(unused_qualifications)]
432            #[automatically_derived]
433            impl #impl_generics ::core::fmt::Display for #ty #ty_generics #display_where_clause {
434                fn fmt(&self, __formatter: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
435                    #use_as_display
436                    #[allow(unused_variables, deprecated, clippy::used_underscore_binding)]
437                    match #void_deref self {
438                        #(#arms,)*
439                    }
440                }
441            }
442        })
443    } else {
444        None
445    };
446
447    let from_impls = input.variants.iter().filter_map(|variant| {
448        let from_field = variant.from_field()?;
449        let span = from_field.attrs.from.unwrap().span;
450        let backtrace_field = variant.distinct_backtrace_field();
451        let variant = &variant.ident;
452        let from = unoptional_type(from_field.ty);
453        let source_var = Ident::new("source", span);
454        let body = from_initializer(from_field, backtrace_field, &source_var);
455        let from_function = quote! {
456            fn from(#source_var: #from) -> Self {
457                #ty::#variant #body
458            }
459        };
460        let from_impl = quote_spanned! {span=>
461            #[automatically_derived]
462            impl #impl_generics ::core::convert::From<#from> for #ty #ty_generics #where_clause {
463                #from_function
464            }
465        };
466        let lint_allows = if input.generics.lifetimes().next().is_some() {
467            Some(quote! {
468                clippy::elidable_lifetime_names,
469                clippy::needless_lifetimes,
470            })
471        } else {
472            None
473        };
474        Some(quote! {
475            #[allow(
476                deprecated,
477                unused_qualifications,
478                #lint_allows
479            )]
480            #from_impl
481        })
482    });
483
484    if input.generics.type_params().next().is_some() {
485        let self_token = <Token![Self]>::default();
486        error_inferred_bounds.insert(self_token, Trait::Debug);
487        error_inferred_bounds.insert(self_token, Trait::Display);
488    }
489    let error_where_clause = error_inferred_bounds.augment_where_clause(input.generics);
490
491    quote! {
492        #[allow(unused_qualifications)]
493        #[automatically_derived]
494        impl #impl_generics ::thiserror::#private::Error for #ty #ty_generics #error_where_clause {
495            #source_method
496            #provide_method
497        }
498        #display_impl
499        #(#from_impls)*
500    }
501}
502
503// Create an ident with which we can expand `impl Trait for #ident {}` on a
504// deprecated type without triggering deprecation warning on the generated impl.
505pub(crate) fn call_site_ident(ident: &Ident) -> Ident {
506    let mut ident = ident.clone();
507    ident.set_span(ident.span().resolved_at(Span::call_site()));
508    ident
509}
510
511fn fields_pat(fields: &[Field]) -> TokenStream {
512    let mut members = fields.iter().map(|field| &field.member).peekable();
513    match members.peek() {
514        Some(MemberUnraw::Named(_)) => quote!({ #(#members),* }),
515        Some(MemberUnraw::Unnamed(_)) => {
516            let vars = members.map(|member| match member {
517                MemberUnraw::Unnamed(index) => format_ident!("_{}", index),
518                MemberUnraw::Named(_) => unreachable!(),
519            });
520            quote!((#(#vars),*))
521        }
522        None => quote!({}),
523    }
524}
525
526fn use_as_display(needs_as_display: bool) -> Option<TokenStream> {
527    if needs_as_display {
528        Some(quote! {
529            use ::thiserror::#private::AsDisplay as _;
530        })
531    } else {
532        None
533    }
534}
535
536fn from_initializer(
537    from_field: &Field,
538    backtrace_field: Option<&Field>,
539    source_var: &Ident,
540) -> TokenStream {
541    let from_member = &from_field.member;
542    let some_source = if type_is_option(from_field.ty) {
543        quote!(::core::option::Option::Some(#source_var))
544    } else {
545        quote!(#source_var)
546    };
547    let backtrace = backtrace_field.map(|backtrace_field| {
548        let backtrace_member = &backtrace_field.member;
549        if type_is_option(backtrace_field.ty) {
550            quote! {
551                #backtrace_member: ::core::option::Option::Some(::thiserror::#private::Backtrace::capture()),
552            }
553        } else {
554            quote! {
555                #backtrace_member: ::core::convert::From::from(::thiserror::#private::Backtrace::capture()),
556            }
557        }
558    });
559    quote!({
560        #from_member: #some_source,
561        #backtrace
562    })
563}
564
565fn type_is_option(ty: &Type) -> bool {
566    type_parameter_of_option(ty).is_some()
567}
568
569fn unoptional_type(ty: &Type) -> TokenStream {
570    let unoptional = type_parameter_of_option(ty).unwrap_or(ty);
571    quote!(#unoptional)
572}
573
574fn type_parameter_of_option(ty: &Type) -> Option<&Type> {
575    let path = match ty {
576        Type::Path(ty) => &ty.path,
577        _ => return None,
578    };
579
580    let last = path.segments.last().unwrap();
581    if last.ident != "Option" {
582        return None;
583    }
584
585    let bracketed = match &last.arguments {
586        PathArguments::AngleBracketed(bracketed) => bracketed,
587        _ => return None,
588    };
589
590    if bracketed.args.len() != 1 {
591        return None;
592    }
593
594    match &bracketed.args[0] {
595        GenericArgument::Type(arg) => Some(arg),
596        _ => None,
597    }
598}