diff --git a/impl/src/valid.rs b/impl/src/valid.rs index d410e52..cbd9f41 100644 --- a/impl/src/valid.rs +++ b/impl/src/valid.rs @@ -188,7 +188,7 @@ fn check_field_attrs(fields: &[Field]) -> Result<()> { } } if let Some(source_field) = source_field.or(from_field) { - if contains_non_static_lifetime(source_field) { + if contains_non_static_lifetime(&source_field.ty) { return Err(Error::new_spanned( &source_field.original.ty, "non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static", @@ -206,21 +206,28 @@ fn same_member(one: &Field, two: &Field) -> bool { } } -fn contains_non_static_lifetime(field: &Field) -> bool { - let ty = match field.ty { - Type::Path(ty) => ty, - _ => return false, // maybe implement later if there are common other cases - }; - let bracketed = match &ty.path.segments.last().unwrap().arguments { - PathArguments::AngleBracketed(bracketed) => bracketed, - _ => return false, - }; - for arg in &bracketed.args { - if let GenericArgument::Lifetime(lifetime) = arg { - if lifetime.ident != "static" { - return true; +fn contains_non_static_lifetime(ty: &Type) -> bool { + match ty { + Type::Path(ty) => { + let bracketed = match &ty.path.segments.last().unwrap().arguments { + PathArguments::AngleBracketed(bracketed) => bracketed, + _ => return false, + }; + for arg in &bracketed.args { + match arg { + GenericArgument::Type(ty) if contains_non_static_lifetime(ty) => return true, + GenericArgument::Lifetime(lifetime) if lifetime.ident != "static" => { + return true + } + _ => {} + } } + false } + Type::Reference(ty) => ty + .lifetime + .as_ref() + .map_or(false, |lifetime| lifetime.ident != "static"), + _ => false, // maybe implement later if there are common other cases } - false } diff --git a/tests/ui/lifetime.rs b/tests/ui/lifetime.rs index 63c6970..698f8c4 100644 --- a/tests/ui/lifetime.rs +++ b/tests/ui/lifetime.rs @@ -1,3 +1,4 @@ +use std::fmt::Debug; use thiserror::Error; #[derive(Error, Debug)] @@ -8,6 +9,16 @@ struct Error<'a>(#[from] Inner<'a>); #[error("{0}")] struct Inner<'a>(&'a str); +#[derive(Error, Debug)] +enum Enum<'a> { + #[error("error")] + Foo(#[from] Generic<&'a str>), +} + +#[derive(Error, Debug)] +#[error("{0:?}")] +struct Generic(T); + fn main() -> Result<(), Error<'static>> { Err(Error(Inner("some text"))) } diff --git a/tests/ui/lifetime.stderr b/tests/ui/lifetime.stderr index fbf21ad..5f86fa0 100644 --- a/tests/ui/lifetime.stderr +++ b/tests/ui/lifetime.stderr @@ -1,5 +1,11 @@ error: non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static - --> $DIR/lifetime.rs:5:26 + --> $DIR/lifetime.rs:6:26 | -5 | struct Error<'a>(#[from] Inner<'a>); +6 | struct Error<'a>(#[from] Inner<'a>); | ^^^^^^^^^ + +error: non-static lifetimes are not allowed in the source of an error, because std::error::Error requires the source is dyn Error + 'static + --> $DIR/lifetime.rs:15:17 + | +15 | Foo(#[from] Generic<&'a str>), + | ^^^^^^^^^^^^^^^^