From 49ac55de554b08182f44d28451f31d06279891e9 Mon Sep 17 00:00:00 2001 From: Octavia Togami Date: Thu, 10 Aug 2023 20:45:35 -0700 Subject: [PATCH 1/4] Allow `self` in non-unit struct and enum top-level assert Also allows top-level map to use `self` instead of `this` --- binrw/tests/derive/enum.rs | 37 ++++++++++++++ binrw/tests/derive/map_args.rs | 24 ++++++++++ binrw/tests/derive/struct.rs | 26 ++++++++++ binrw_derive/src/binrw/codegen/mod.rs | 1 + .../src/binrw/codegen/read_options/enum.rs | 13 ++--- .../src/binrw/codegen/read_options/struct.rs | 48 +++++++++++++++---- binrw_derive/src/binrw/parser/types/assert.rs | 19 +++++++- 7 files changed, 151 insertions(+), 17 deletions(-) diff --git a/binrw/tests/derive/enum.rs b/binrw/tests/derive/enum.rs index 572ed41b..50306dbe 100644 --- a/binrw/tests/derive/enum.rs +++ b/binrw/tests/derive/enum.rs @@ -28,6 +28,43 @@ fn enum_assert() { Test::read_le(&mut Cursor::new(b"\0\0\x01")).expect_err("accepted bad data"); } +#[test] +fn enum_assert_with_self() { + #[derive(BinRead, Debug, PartialEq)] + #[br(assert(self.verify()))] + enum Test { + A { + a: u8, + b: u8, + }, + #[br(assert(self.verify_only_b()))] + B { + a: i16, + b: u8, + }, + } + + impl Test { + fn verify(&self) -> bool { + match self { + Test::A { b, .. } => *b == 1, + Test::B { a, b } => *a == -1 && *b == 1, + } + } + + fn verify_only_b(&self) -> bool { + matches!(self, Test::B { .. 
}) + } + } + + assert_eq!( + Test::read_le(&mut Cursor::new(b"\xff\xff\x01")).unwrap(), + Test::B { a: -1, b: 1 } + ); + Test::read_le(&mut Cursor::new(b"\xff\xff\0")).expect_err("accepted bad data"); + Test::read_le(&mut Cursor::new(b"\0\0\x01")).expect_err("accepted bad data"); +} + #[test] fn enum_non_copy_args() { #[derive(BinRead, Debug)] diff --git a/binrw/tests/derive/map_args.rs b/binrw/tests/derive/map_args.rs index 968ba60d..cf2e0ef5 100644 --- a/binrw/tests/derive/map_args.rs +++ b/binrw/tests/derive/map_args.rs @@ -48,3 +48,27 @@ fn map_field_assert_access_fields() { Test::read(&mut Cursor::new(b"a")).unwrap(); } + +#[test] +#[should_panic] +fn map_top_assert_legacy_this() { + #[derive(BinRead, Debug, Eq, PartialEq)] + #[br(assert(this.x == 2), map(|_: u8| Test { x: 3 }))] + struct Test { + x: u8, + } + + Test::read(&mut Cursor::new(b"a")).unwrap(); +} + +#[test] +#[should_panic] +fn map_top_assert_via_self() { + #[derive(BinRead, Debug, Eq, PartialEq)] + #[br(assert(self.x == 2), map(|_: u8| Test { x: 3 }))] + struct Test { + x: u8, + } + + Test::read(&mut Cursor::new(b"a")).unwrap(); +} diff --git a/binrw/tests/derive/struct.rs b/binrw/tests/derive/struct.rs index 097ed8c8..d54542fe 100644 --- a/binrw/tests/derive/struct.rs +++ b/binrw/tests/derive/struct.rs @@ -704,6 +704,32 @@ fn reader_var() { ); } +#[test] +fn top_level_assert_has_self() { + #[allow(dead_code)] + #[derive(BinRead, Debug)] + #[br(assert(self.verify(), "verify failed"))] + struct Test { + a: u8, + b: u8, + } + + impl Test { + fn verify(&self) -> bool { + self.a == self.b + } + } + + let mut data = Cursor::new(b"\x01\x01"); + Test::read_le(&mut data).expect("a == b passed"); + let mut data = Cursor::new(b"\x01\x02"); + let err = Test::read_le(&mut data).expect_err("a == b failed"); + assert!(matches!(err, binrw::Error::AssertFail { + message, + .. 
+ } if message == "verify failed")); +} + #[test] fn rewind_on_assert() { #[allow(dead_code)] diff --git a/binrw_derive/src/binrw/codegen/mod.rs b/binrw_derive/src/binrw/codegen/mod.rs index 700d1622..dcdfda3b 100644 --- a/binrw_derive/src/binrw/codegen/mod.rs +++ b/binrw_derive/src/binrw/codegen/mod.rs @@ -204,6 +204,7 @@ fn get_assertions(assertions: &[Assert]) -> impl Iterator + kw_span, condition, consequent, + .. }| { let error_fn = match &consequent { Some(AssertionError::Message(message)) => { diff --git a/binrw_derive/src/binrw/codegen/read_options/enum.rs b/binrw_derive/src/binrw/codegen/read_options/enum.rs index f3e96327..010300ff 100644 --- a/binrw_derive/src/binrw/codegen/read_options/enum.rs +++ b/binrw_derive/src/binrw/codegen/read_options/enum.rs @@ -3,12 +3,9 @@ use super::{ PreludeGenerator, }; use crate::binrw::{ - codegen::{ - get_assertions, - sanitization::{ - BACKTRACE_FRAME, BIN_ERROR, ERROR_BASKET, OPT, POS, READER, READ_METHOD, - RESTORE_POSITION_VARIANT, TEMP, WITH_CONTEXT, - }, + codegen::sanitization::{ + BACKTRACE_FRAME, BIN_ERROR, ERROR_BASKET, OPT, POS, READER, READ_METHOD, + RESTORE_POSITION_VARIANT, TEMP, WITH_CONTEXT, }, parser::{Enum, EnumErrorMode, EnumVariant, Input, UnitEnumField, UnitOnlyEnum}, }; @@ -228,8 +225,8 @@ fn generate_variant_impl(en: &Enum, variant: &EnumVariant) -> TokenStream { None, Some(&format!("{}::{}", en.ident.as_ref().unwrap(), &ident)), ) - .add_assertions(get_assertions(&en.assertions)) - .return_value(Some(ident)) + .initialize_value_with_assertions(Some(ident), &en.assertions) + .return_value() .finish(), EnumVariant::Unit(options) => generate_unit_struct(&input, None, Some(&options.ident)), diff --git a/binrw_derive/src/binrw/codegen/read_options/struct.rs b/binrw_derive/src/binrw/codegen/read_options/struct.rs index b2044763..09b7e1ba 100644 --- a/binrw_derive/src/binrw/codegen/read_options/struct.rs +++ b/binrw_derive/src/binrw/codegen/read_options/struct.rs @@ -20,6 +20,7 @@ use 
alloc::borrow::Cow; use proc_macro2::TokenStream; use quote::{quote, quote_spanned, ToTokens}; use syn::{spanned::Spanned, Ident}; +use crate::binrw::parser::Assert; pub(super) fn generate_unit_struct( input: &Input, @@ -37,8 +38,8 @@ pub(super) fn generate_unit_struct( pub(super) fn generate_struct(input: &Input, name: Option<&Ident>, st: &Struct) -> TokenStream { StructGenerator::new(input, st) .read_fields(name, None) - .add_assertions(core::iter::empty()) - .return_value(None) + .initialize_value_with_assertions(None, &[]) + .return_value() .finish() } @@ -61,11 +62,31 @@ impl<'input> StructGenerator<'input> { self.out } - pub(super) fn add_assertions( - mut self, - extra_assertions: impl Iterator, + pub(super) fn initialize_value_with_assertions( + self, + variant_ident: Option<&Ident>, + extra_assertions: &[Assert], ) -> Self { - let assertions = get_assertions(&self.st.assertions).chain(extra_assertions); + if self.has_self_assertions(extra_assertions) { + self.init_value(variant_ident) + .add_assertions(extra_assertions) + } else { + self.add_assertions(extra_assertions) + .init_value(variant_ident) + } + } + + fn has_self_assertions(&self, extra_assertions: &[Assert]) -> bool { + self.st + .assertions + .iter() + .chain(extra_assertions) + .any(|assert| assert.condition_uses_self) + } + + fn add_assertions(mut self, extra_assertions: &[Assert]) -> Self { + let assertions = get_assertions(&self.st.assertions) + .chain(get_assertions(extra_assertions)); let head = self.out; self.out = quote! { #head @@ -102,7 +123,7 @@ impl<'input> StructGenerator<'input> { self } - pub(super) fn return_value(mut self, variant_ident: Option<&Ident>) -> Self { + fn init_value(mut self, variant_ident: Option<&Ident>) -> Self { let out_names = self.st.iter_permanent_idents(); let return_type = get_return_type(variant_ident); let return_value = if self.st.is_tuple() { @@ -114,7 +135,18 @@ impl<'input> StructGenerator<'input> { let head = self.out; self.out = quote! 
{ #head - Ok(#return_value) + let this = #return_value; + }; + + self + } + + pub(super) fn return_value(mut self) -> Self { + let head = self.out; + + self.out = quote! { + #head + Ok(this) }; self diff --git a/binrw_derive/src/binrw/parser/types/assert.rs b/binrw_derive/src/binrw/parser/types/assert.rs index ccb34b36..cea122f3 100644 --- a/binrw_derive/src/binrw/parser/types/assert.rs +++ b/binrw_derive/src/binrw/parser/types/assert.rs @@ -1,5 +1,5 @@ use crate::{binrw::parser::attrs, meta_types::KeywordToken}; -use proc_macro2::{Span, TokenStream}; +use proc_macro2::{Ident, Span, TokenStream, TokenTree}; use quote::{quote, ToTokens}; use syn::{parse::Parse, spanned::Spanned, token::Token, Expr, ExprLit, Lit}; @@ -13,6 +13,9 @@ pub(crate) enum Error { pub(crate) struct Assert { pub(crate) kw_span: Span, pub(crate) condition: TokenStream, + /// `true` if the condition was written with `self`, in the [`condition`] it is replaced with + /// `this`. This enables backwards compatibility with asserts that did not use `self`. 
+ pub(crate) condition_uses_self: bool, pub(crate) consequent: Option, } @@ -35,6 +38,19 @@ impl TryFrom> for Assert { )); }; + // ignores any alternative declaration of `self` in the condition, but asserts should be + // simple so that shouldn't be a problem + let mut condition_uses_self = false; + let condition: TokenStream = condition.into_iter().map(|tt| { + match tt { + TokenTree::Ident(ref i) if i == "self" => { + condition_uses_self = true; + TokenTree::Ident(Ident::new("this", i.span())) + } + other => other, + } + }).collect(); + let consequent = match args.next() { Some(Expr::Lit(ExprLit { lit: Lit::Str(message), @@ -53,6 +69,7 @@ impl TryFrom> for Assert { Ok(Self { kw_span, condition, + condition_uses_self, consequent, }) } From 410782cfcaffc47e7abe286f417d604bb0861252 Mon Sep 17 00:00:00 2001 From: Octavia Togami Date: Thu, 10 Aug 2023 21:44:27 -0700 Subject: [PATCH 2/4] Add some docs for self in asserts --- binrw/doc/attribute.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/binrw/doc/attribute.md b/binrw/doc/attribute.md index 7ebe3396..a978115f 100644 --- a/binrw/doc/attribute.md +++ b/binrw/doc/attribute.md @@ -502,6 +502,15 @@ Any (earlier only, when reading)earlie field or [import](#arguments) can be referenced by expressions in the directive. +
+ +For `#[br]`, when using `map`, or on a non-unit `struct` or an `enum`, a special variable +named `self` can be referenced by expressions in the directive. It contains the result +of the `map` function or the newly constructed `struct` or `enum` value. Note that the +fields of an `enum` variant cannot be read as `self.field`, because a variant is not its own type; use `match` instead. + +
+ ## Examples ### Formatted error From 69da8727270f8be7624804014214ae9a7b53adc2 Mon Sep 17 00:00:00 2001 From: Octavia Togami Date: Thu, 10 Aug 2023 21:54:26 -0700 Subject: [PATCH 3/4] Fix formatting --- binrw_derive/src/binrw/codegen/read_options/struct.rs | 6 +++--- binrw_derive/src/binrw/parser/types/assert.rs | 9 +++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/binrw_derive/src/binrw/codegen/read_options/struct.rs b/binrw_derive/src/binrw/codegen/read_options/struct.rs index 09b7e1ba..bd872e9b 100644 --- a/binrw_derive/src/binrw/codegen/read_options/struct.rs +++ b/binrw_derive/src/binrw/codegen/read_options/struct.rs @@ -1,6 +1,7 @@ use super::{get_magic, PreludeGenerator}; #[cfg(feature = "verbose-backtrace")] use crate::binrw::backtrace::BacktraceFrame; +use crate::binrw::parser::Assert; use crate::{ binrw::{ codegen::{ @@ -20,7 +21,6 @@ use alloc::borrow::Cow; use proc_macro2::TokenStream; use quote::{quote, quote_spanned, ToTokens}; use syn::{spanned::Spanned, Ident}; -use crate::binrw::parser::Assert; pub(super) fn generate_unit_struct( input: &Input, @@ -85,8 +85,8 @@ impl<'input> StructGenerator<'input> { } fn add_assertions(mut self, extra_assertions: &[Assert]) -> Self { - let assertions = get_assertions(&self.st.assertions) - .chain(get_assertions(extra_assertions)); + let assertions = + get_assertions(&self.st.assertions).chain(get_assertions(extra_assertions)); let head = self.out; self.out = quote! 
{ #head diff --git a/binrw_derive/src/binrw/parser/types/assert.rs b/binrw_derive/src/binrw/parser/types/assert.rs index cea122f3..48b44382 100644 --- a/binrw_derive/src/binrw/parser/types/assert.rs +++ b/binrw_derive/src/binrw/parser/types/assert.rs @@ -41,15 +41,16 @@ impl TryFrom> for Assert { // ignores any alternative declaration of `self` in the condition, but asserts should be // simple so that shouldn't be a problem let mut condition_uses_self = false; - let condition: TokenStream = condition.into_iter().map(|tt| { - match tt { + let condition: TokenStream = condition + .into_iter() + .map(|tt| match tt { TokenTree::Ident(ref i) if i == "self" => { condition_uses_self = true; TokenTree::Ident(Ident::new("this", i.span())) } other => other, - } - }).collect(); + }) + .collect(); let consequent = match args.next() { Some(Expr::Lit(ExprLit { From b52e37994e41201e849e3f959f6653331a067677 Mon Sep 17 00:00:00 2001 From: Octavia Togami Date: Thu, 10 Aug 2023 22:21:04 -0700 Subject: [PATCH 4/4] Fix nested self usage with syn::Fold --- binrw/tests/derive/struct.rs | 26 +++++++++++++ binrw_derive/src/binrw/parser/types/assert.rs | 39 +++++++++++-------- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/binrw/tests/derive/struct.rs b/binrw/tests/derive/struct.rs index d54542fe..8a9d831d 100644 --- a/binrw/tests/derive/struct.rs +++ b/binrw/tests/derive/struct.rs @@ -730,6 +730,32 @@ fn top_level_assert_has_self() { } if message == "verify failed")); } +#[test] +fn top_level_assert_self_weird() { + #[allow(dead_code)] + #[derive(BinRead, Debug)] + #[br(assert(Test::verify(&self), "verify failed"))] + struct Test { + a: u8, + b: u8, + } + + impl Test { + fn verify(&self) -> bool { + self.a == self.b + } + } + + let mut data = Cursor::new(b"\x01\x01"); + Test::read_le(&mut data).expect("a == b passed"); + let mut data = Cursor::new(b"\x01\x02"); + let err = Test::read_le(&mut data).expect_err("a == b failed"); + assert!(matches!(err, 
binrw::Error::AssertFail { + message, + .. + } if message == "verify failed")); +} + #[test] fn rewind_on_assert() { #[allow(dead_code)] diff --git a/binrw_derive/src/binrw/parser/types/assert.rs b/binrw_derive/src/binrw/parser/types/assert.rs index 48b44382..d7bd6b1d 100644 --- a/binrw_derive/src/binrw/parser/types/assert.rs +++ b/binrw_derive/src/binrw/parser/types/assert.rs @@ -1,6 +1,7 @@ use crate::{binrw::parser::attrs, meta_types::KeywordToken}; -use proc_macro2::{Ident, Span, TokenStream, TokenTree}; +use proc_macro2::{Ident, Span, TokenStream}; use quote::{quote, ToTokens}; +use syn::fold::Fold; use syn::{parse::Parse, spanned::Spanned, token::Token, Expr, ExprLit, Lit}; #[derive(Debug, Clone)] @@ -26,9 +27,7 @@ impl TryFrom> for Assert { let kw_span = value.keyword_span(); let mut args = value.fields.iter(); - let condition = if let Some(cond) = args.next() { - cond.into_token_stream() - } else { + let Some(cond_expr) = args.next() else { return Err(Self::Error::new( kw_span, format!( @@ -40,17 +39,8 @@ impl TryFrom> for Assert { // ignores any alternative declaration of `self` in the condition, but asserts should be // simple so that shouldn't be a problem - let mut condition_uses_self = false; - let condition: TokenStream = condition - .into_iter() - .map(|tt| match tt { - TokenTree::Ident(ref i) if i == "self" => { - condition_uses_self = true; - TokenTree::Ident(Ident::new("this", i.span())) - } - other => other, - }) - .collect(); + let mut self_replacer = ReplaceSelfWithThis { uses_self: false }; + let cond_expr = self_replacer.fold_expr(cond_expr.clone()); let consequent = match args.next() { Some(Expr::Lit(ExprLit { @@ -69,9 +59,24 @@ impl TryFrom> for Assert { Ok(Self { kw_span, - condition, - condition_uses_self, + condition: cond_expr.into_token_stream(), + condition_uses_self: self_replacer.uses_self, consequent, }) } } + +struct ReplaceSelfWithThis { + uses_self: bool, +} + +impl Fold for ReplaceSelfWithThis { + fn fold_ident(&mut self, i: 
Ident) -> Ident { + if i == "self" { + self.uses_self = true; + Ident::new("this", i.span()) + } else { + i + } + } +}