Add #ifdef / #ifndef / #else / #endif conditional support to the token parser,
allow value-less "#define NAME", and add tests covering conditionals.

NOTE(review): this patch was restored from a whitespace-collapsed copy. Hunk line
counts were re-verified against the @@ headers, but context-line indentation and
the CONFIG_FLAG group name ("values") are reconstructed - confirm against base
revision 8f214f6 before applying.

diff --git a/dissect/cstruct/parser.py b/dissect/cstruct/parser.py
index 8f214f6..9f465e2 100644
--- a/dissect/cstruct/parser.py
+++ b/dissect/cstruct/parser.py
@@ -49,12 +49,18 @@ def __init__(self, cs: cstruct, compiled: bool = True, align: bool = False):
         self.compiled = compiled
         self.align = align
         self.TOK = self._tokencollection()
+        self._conditionals = []
+        self._conditionals_depth = 0
 
     @staticmethod
     def _tokencollection() -> TokenCollection:
         TOK = TokenCollection()
         TOK.add(r"#\[(?P<values>[^\]]+)\](?=\s*)", "CONFIG_FLAG")
-        TOK.add(r"#define\s+(?P<name>[^\s]+)\s+(?P<value>[^\r\n]+)\s*", "DEFINE")
+        TOK.add(r"#define\s+(?P<name>[^\s]+)(?P<value>[^\r\n]*)", "DEFINE")
+        TOK.add(r"#ifdef\s+(?P<name>[^\s]+)\s*", "IFDEF")
+        TOK.add(r"#ifndef\s+(?P<name>[^\s]+)\s*", "IFNDEF")
+        TOK.add(r"#else\s*", "ELSE")
+        TOK.add(r"#endif\s*", "ENDIF")
         TOK.add(r"typedef(?=\s)", "TYPEDEF")
         TOK.add(r"(?:struct|union)(?=\s|{)", "STRUCT")
         TOK.add(
@@ -80,12 +86,61 @@ def _identifier(self, tokens: TokenConsumer) -> str:
             idents.append(tokens.consume())
         return " ".join([i.value for i in idents])
 
+    def _conditional(self, tokens: TokenConsumer) -> None:
+        token = tokens.consume()
+        pattern = self.TOK.patterns[token.token]
+        match = pattern.match(token.value).groupdict()
+
+        value = match["name"]
+
+        if token.token == self.TOK.IFDEF:
+            self._conditionals.append(value in self.cstruct.consts)
+        elif token.token == self.TOK.IFNDEF:
+            self._conditionals.append(value not in self.cstruct.consts)
+
+    def _check_conditional(self, tokens: TokenConsumer) -> bool:
+        """Check and handle conditionals. Return a boolean indicating if we need to continue to the next token."""
+        if self._conditionals and self._conditionals_depth == len(self._conditionals):
+            # If we have a conditional and the depth matches, handle it accordingly
+            if tokens.next == self.TOK.ELSE:
+                # Flip the last conditional
+                tokens.consume()
+                self._conditionals[-1] = not self._conditionals[-1]
+                return True
+
+            if tokens.next == self.TOK.ENDIF:
+                # Pop the last conditional
+                tokens.consume()
+                self._conditionals.pop()
+                self._conditionals_depth -= 1
+                return True
+
+        if tokens.next in (self.TOK.IFDEF, self.TOK.IFNDEF):
+            # If we encounter a new conditional, increase the depth
+            self._conditionals_depth += 1
+
+        if tokens.next == self.TOK.ENDIF:
+            # Similarly, decrease the depth if needed
+            self._conditionals_depth -= 1
+
+        if self._conditionals and not self._conditionals[-1]:
+            # If the last conditional evaluated to False, skip the next token
+            tokens.consume()
+            return True
+
+        if tokens.next in (self.TOK.IFDEF, self.TOK.IFNDEF):
+            # If the next token is a conditional, process it
+            self._conditional(tokens)
+            return True
+
+        return False
+
     def _constant(self, tokens: TokenConsumer) -> None:
         const = tokens.consume()
         pattern = self.TOK.patterns[self.TOK.DEFINE]
         match = pattern.match(const.value).groupdict()
 
-        value = match["value"]
+        value = match["value"].strip()
         try:
             value = ast.literal_eval(value)
         except (ValueError, SyntaxError):
@@ -208,6 +263,9 @@ def _struct(self, tokens: TokenConsumer, register: bool = False) -> type[Structu
                 tokens.consume()
                 break
 
+            if self._check_conditional(tokens):
+                continue
+
             field = self._parse_field(tokens)
             fields.append(field)
 
@@ -266,7 +324,7 @@ def _parse_field(self, tokens: TokenConsumer) -> Field:
                 return Field(None, type_, None)
 
         if tokens.next != self.TOK.NAME:
-            raise ParserError(f"line {self._lineno(tokens.next)}: expected name")
+            raise ParserError(f"line {self._lineno(tokens.next)}: expected name, got {tokens.next!r}")
 
         nametok = tokens.consume()
         type_, name, bits = self._parse_field_type(type_, nametok.value)
@@ -378,6 +436,9 @@ def parse(self, data: str) -> None:
             if token is None:
                 break
 
+            if self._check_conditional(tokens):
+                continue
+
             if token == self.TOK.CONFIG_FLAG:
                 self._config_flag(tokens)
             elif token == self.TOK.DEFINE:
@@ -395,6 +456,9 @@ def parse(self, data: str) -> None:
             else:
                 raise ParserError(f"line {self._lineno(token)}: unexpected token {token!r}")
 
+        if self._conditionals:
+            raise ParserError(f"line {self._lineno(tokens.previous)}: unclosed conditional statement")
+
 
 class CStyleParser(Parser):
     """Definition parser for C-like structure syntax.
diff --git a/tests/test_parser.py b/tests/test_parser.py
index d049756..71fe8ab 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -1,18 +1,15 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
 from unittest.mock import Mock
 
 import pytest
 
+from dissect.cstruct import cstruct
 from dissect.cstruct.exceptions import ParserError
 from dissect.cstruct.parser import TokenParser
 from dissect.cstruct.types import BaseArray, Pointer, Structure
 from tests.utils import verify_compiled
 
-if TYPE_CHECKING:
-    from dissect.cstruct import cstruct
-
 
 def test_nested_structs(cs: cstruct, compiled: bool) -> None:
     cdef = """
@@ -164,3 +161,123 @@ def test_typedef_pointer(cs: cstruct) -> None:
     assert cs.IMAGE_DATA_DIRECTORY is cs._IMAGE_DATA_DIRECTORY
     assert issubclass(cs.PIMAGE_DATA_DIRECTORY, Pointer)
     assert cs.PIMAGE_DATA_DIRECTORY.type == cs._IMAGE_DATA_DIRECTORY
+
+
+def test_conditional_ifdef(cs: cstruct) -> None:
+    cdef = """
+    #define MY_CONST 42
+
+    #ifdef MY_CONST
+    struct test {
+        uint32 a;
+    };
+    #endif
+    """
+    cs.load(cdef)
+
+    assert "test" in cs.typedefs
+
+
+def test_conditional_ifndef(cs: cstruct) -> None:
+    cdef = """
+    #ifndef MYVAR
+    #define MYVAR (1)
+    #endif
+    """
+    cs.load(cdef)
+
+    assert "MYVAR" in cs.consts
+    assert cs.consts["MYVAR"] == 1
+
+
+def test_conditional_ifndef_guard(cs: cstruct) -> None:
+    cdef = """
+    /* Define Guard */
+    #ifndef __MYGUARD
+    #define __MYGUARD
+
+    typedef struct myStruct
+    {
+        char charVal[16];
+    }
+    #endif // __MYGUARD
+    """
+    cs.load(cdef)
+
+    assert "__MYGUARD" in cs.consts
+    assert "myStruct" in cs.typedefs
+
+
+def test_conditional_nested() -> None:
+    cdef = """
+    #ifndef MYSWITCH1
+    #define MYVAR1 (1)
+    #else
+    #ifdef MYSWITCH2
+    #define MYVAR1 (2)
+    #else
+    #define MYVAR1 (3)
+    #endif
+    #endif
+    """
+    cs = cstruct().load(cdef)
+
+    assert "MYVAR1" in cs.consts
+    assert cs.consts["MYVAR1"] == 1
+
+    cs = cstruct().load("#define MYSWITCH1")
+
+    assert "MYSWITCH1" in cs.consts
+
+    cs.load(cdef)
+
+    assert "MYVAR1" in cs.consts
+    assert cs.consts["MYVAR1"] == 3
+
+
+def test_conditional_in_struct(cs: cstruct) -> None:
+    cdef = """
+    struct t_bitfield {
+        union {
+            struct {
+                uint32_t bit0:1;
+                uint32_t bit1:1;
+    #ifdef MYSWT
+                uint32_t bit2:1;
+    #endif
+            } fval;
+            uint32_t bits;
+        };
+    };
+    """
+    cs.load(cdef)
+
+    assert "t_bitfield" in cs.typedefs
+    assert "fval" in cs.t_bitfield.fields
+    assert "bit0" in cs.t_bitfield.fields["fval"].type.fields
+    assert "bit1" in cs.t_bitfield.fields["fval"].type.fields
+    assert "bit2" not in cs.t_bitfield.fields["fval"].type.fields
+
+
+def test_conditional_parsing_error(cs: cstruct) -> None:
+    cdef = """
+    #ifndef __HELP
+    #define __HELP
+    #endif
+    struct test {
+        uint32 a;
+    };
+    #endif
+    """
+    with pytest.raises(ParserError, match="line 8: unexpected token .+ENDIF"):
+        cs.load(cdef)
+
+    cdef = """
+    #ifndef __HELP
+    #define __HELP
+    struct test {
+        uint32 a;
+    };
+    """
+    with pytest.raises(ParserError, match="line 6: unclosed conditional statement"):
+        cs.load(cdef)