diff --git a/dissect/cstruct/cstruct.py b/dissect/cstruct/cstruct.py index 2907734..f5fc7cd 100644 --- a/dissect/cstruct/cstruct.py +++ b/dissect/cstruct/cstruct.py @@ -52,9 +52,12 @@ def __init__(self, load: str = "", *, endian: str = "<", pointer: str | None = N self.consts = {} self.lookups = {} + self.types = {} + self.typedefs = {} self.includes = [] + # fmt: off - self.typedefs = { + initial_types = { # Internal types "int8": self._make_packed_type("int8", "b", int), "uint8": self._make_packed_type("uint8", "B", int), @@ -98,6 +101,21 @@ def __init__(self, load: str = "", *, endian: str = "<", pointer: str | None = N "signed long long": "int64", "unsigned long long": "uint64", + # Other convenience types + "u1": "uint8", + "u2": "uint16", + "u4": "uint32", + "u8": "uint64", + "u16": "uint128", + "__u8": "uint8", + "__u16": "uint16", + "__u32": "uint32", + "__u64": "uint64", + "uchar": "uint8", + "ushort": "uint16", + "uint": "uint32", + "ulong": "uint32", + # Windows types "BYTE": "uint8", "CHAR": "char", @@ -165,24 +183,12 @@ def __init__(self, load: str = "", *, endian: str = "<", pointer: str | None = N "_DWORD": "uint32", "_QWORD": "uint64", "_OWORD": "uint128", - - # Other convenience types - "u1": "uint8", - "u2": "uint16", - "u4": "uint32", - "u8": "uint64", - "u16": "uint128", - "__u8": "uint8", - "__u16": "uint16", - "__u32": "uint32", - "__u64": "uint64", - "uchar": "uint8", - "ushort": "uint16", - "uint": "uint32", - "ulong": "uint32", } # fmt: on + for name, type_ in initial_types.items(): + self.add_type(name, type_) + pointer = pointer or ("uint64" if sys.maxsize > 2**32 else "uint32") self.pointer: type[BaseType] = self.resolve(pointer) self._anonymous_count = 0 @@ -196,37 +202,71 @@ def __getattr__(self, attr: str) -> Any: except KeyError: pass + try: + return self.types[attr] + except KeyError: + pass + try: return self.resolve(self.typedefs[attr]) except KeyError: pass - raise AttributeError(f"Invalid attribute: {attr}") + return super().__getattribute__(attr) def _next_anonymous(self) -> str: name = f"__anonymous_{self._anonymous_count}__" self._anonymous_count += 1 return name + def _add_attr(self, name: str, value: Any, replace: bool = False) -> None: + if not replace and (name in self.__dict__ and self.__dict__[name] != value): + raise ValueError(f"Attribute already exists: {name}") + setattr(self, name, value) + def add_type(self, name: str, type_: type[BaseType] | str, replace: bool = False) -> None: """Add a type or type reference. Only use this method when creating type aliases or adding already bound types. + All types will be resolved to their actual type objects prior to being added. + Use :func:`add_typedef` to add type references. Args: name: Name of the type to be added. type_: The type to be added. Can be a str reference to another type or a compatible type class. + If a str is given, it will be resolved to the actual type object. Raises: ValueError: If the type already exists. """ - if not replace and (name in self.typedefs and self.resolve(self.typedefs[name]) != self.resolve(type_)): + typeobj = self.resolve(type_) + if not replace and (name in self.types and self.types[name] != typeobj): raise ValueError(f"Duplicate type: {name}") - self.typedefs[name] = type_ + self.types[name] = typeobj + self._add_attr(name, typeobj, replace=replace) addtype = add_type + def add_typedef(self, name: str, type_: str, replace: bool = False) -> None: + """Add a type reference. + + Use this method to add type references to this cstruct instance. These are type names that can be + dynamically resolved at a later stage. Use :func:`add_type` to add actual type objects. + + Args: + name: Name of the type to be added. + type_: The type reference to be added. + replace: Whether to replace the type if it already exists. + """ + if not isinstance(type_, str): + raise TypeError("Type reference must be a string") + + if not replace and (name in self.typedefs and self.resolve(self.typedefs[name]) != self.resolve(type_)): + raise ValueError(f"Duplicate type: {name}") + + self.typedefs[name] = type_ + def add_custom_type( self, name: str, type_: type[BaseType], size: int | None = None, alignment: int | None = None, **kwargs ) -> None: @@ -244,6 +284,16 @@ def add_custom_type( """ self.add_type(name, self._make_type(name, (type_,), size, alignment=alignment, attrs=kwargs)) + def add_const(self, name: str, value: Any) -> None: + """Add a constant value. + + Args: + name: Name of the constant to be added. + value: The value of the constant. + """ + self.consts[name] = value + self._add_attr(name, value, replace=True) + def load(self, definition: str, deftype: int | None = None, **kwargs) -> cstruct: """Parse structures from the given definitions using the given definition type. @@ -315,14 +365,14 @@ def resolve(self, name: type[BaseType] | str) -> type[BaseType]: return type_name for _ in range(10): + if type_name in self.types: + return self.types[type_name] + if type_name not in self.typedefs: raise ResolveError(f"Unknown type {name}") type_name = self.typedefs[type_name] - if not isinstance(type_name, str): - return type_name - raise ResolveError(f"Recursion limit exceeded while resolving type {name}") def _make_type( diff --git a/dissect/cstruct/parser.py b/dissect/cstruct/parser.py index 32e23f0..07e1f95 100644 --- a/dissect/cstruct/parser.py +++ b/dissect/cstruct/parser.py @@ -153,7 +153,7 @@ def _constant(self, tokens: TokenConsumer) -> None: except (ExpressionParserError, ExpressionTokenizerError): pass - self.cstruct.consts[match["name"]] = value + self.cstruct.add_const(match["name"], value) def _undef(self, tokens: TokenConsumer) -> None: const = tokens.consume() @@ -204,7 +204,8 @@ def _enum(self, tokens: TokenConsumer) -> None: enum = factory(d["name"] or "", self.cstruct.resolve(d["type"]), values) if not enum.__name__: - self.cstruct.consts.update(enum.__members__) + for k, v in enum.__members__.items(): + self.cstruct.add_const(k, v) else: self.cstruct.add_type(enum.__name__, enum) @@ -212,12 +213,14 @@ def _enum(self, tokens: TokenConsumer) -> None: def _typedef(self, tokens: TokenConsumer) -> None: tokens.consume() + type_name = None type_ = None names = [] if tokens.next == self.TOK.IDENTIFIER: - type_ = self.cstruct.resolve(self._identifier(tokens)) + type_name = self._identifier(tokens) + type_ = self.cstruct.resolve(type_name) elif tokens.next == self.TOK.STRUCT: type_ = self._struct(tokens) if not type_.__anonymous__: @@ -230,10 +233,13 @@ def _typedef(self, tokens: TokenConsumer) -> None: type_.__name__ = name type_.__qualname__ = name - type_, name, bits = self._parse_field_type(type_, name) + new_type, name, bits = self._parse_field_type(type_, name) if bits is not None: raise ParserError(f"line {self._lineno(tokens.previous)}: typedefs cannot have bitfields") - self.cstruct.add_type(name, type_) + if type_name is None or new_type is not type_: + self.cstruct.add_type(name, new_type) + else: + self.cstruct.add_typedef(name, type_name) def _struct(self, tokens: TokenConsumer, register: bool = False) -> type[Structure]: stype = tokens.consume() @@ -496,7 +502,7 @@ def _constants(self, data: str) -> None: except (ValueError, SyntaxError): pass - self.cstruct.consts[d["name"]] = v + self.cstruct.add_const(d["name"], v) def _enums(self, data: str) -> None: r = re.finditer( @@ -578,7 +584,7 @@ def _structs(self, data: str) -> None: if d["defs"]: for td in d["defs"].strip().split(","): td = td.strip() - self.cstruct.add_type(td, st) + self.cstruct.add_typedef(td, st) def _parse_fields(self, data: str) -> None: fields = re.finditer( diff --git a/dissect/cstruct/tools/stubgen.py b/dissect/cstruct/tools/stubgen.py index 02aaebb..e43d091 100644 --- a/dissect/cstruct/tools/stubgen.py +++ b/dissect/cstruct/tools/stubgen.py @@ -79,32 +79,43 @@ def generate_cstruct_stub(cs: cstruct, module_prefix: str = "", cls_name: str = defined_names = set() - # Then typedefs - for name, typedef in cs.typedefs.items(): - if name in empty_cs.typedefs: + # Then types + for name, type_ in cs.types.items(): + if name in empty_cs.types: continue - if typedef.__name__ in empty_cs.typedefs: - stub = f"{name}: TypeAlias = {cs_prefix}{typedef.__name__}" - elif typedef.__name__ in defined_names: + if type_.__name__ in empty_cs.types: + stub = f"{name}: TypeAlias = {cs_prefix}{type_.__name__}" + elif type_.__name__ in defined_names: # Create an alias to the type if we have already seen it before. - stub = f"{name}: TypeAlias = {typedef.__name__}" - elif issubclass(typedef, (types.Enum, types.Flag)): - stub = generate_enum_stub(typedef, cs_prefix=cs_prefix, module_prefix=module_prefix) - elif issubclass(typedef, types.Pointer): - typehint = generate_typehint(typedef, prefix=cs_prefix, module_prefix=module_prefix) + stub = f"{name}: TypeAlias = {type_.__name__}" + elif issubclass(type_, (types.Enum, types.Flag)): + stub = generate_enum_stub(type_, cs_prefix=cs_prefix, module_prefix=module_prefix) + elif issubclass(type_, types.Pointer): + typehint = generate_typehint(type_, prefix=cs_prefix, module_prefix=module_prefix) stub = f"{name}: TypeAlias = {typehint}" - elif issubclass(typedef, types.Structure): - stub = generate_structure_stub(typedef, cs_prefix=cs_prefix, module_prefix=module_prefix) - elif issubclass(typedef, types.BaseType): - stub = generate_generic_stub(typedef, cs_prefix=cs_prefix, module_prefix=module_prefix) - elif isinstance(typedef, str): - stub = f"{name}: TypeAlias = {typedef}" + elif issubclass(type_, types.Structure): + stub = generate_structure_stub(type_, cs_prefix=cs_prefix, module_prefix=module_prefix) + elif issubclass(type_, types.BaseType): + stub = generate_generic_stub(type_, cs_prefix=cs_prefix, module_prefix=module_prefix) + elif isinstance(type_, str): + stub = f"{name}: TypeAlias = {type_}" else: - raise TypeError(f"Unknown typedef: {typedef}") + raise TypeError(f"Unknown type: {type_}") + + defined_names.add(type_.__name__) + + body.append(textwrap.indent(stub, prefix=indent)) + + # Then typedefs + for name, typedef in cs.typedefs.items(): + if name in empty_cs.typedefs: + continue - defined_names.add(typedef.__name__) + if not isinstance(typedef, str): + raise TypeError(f"Expected typedef to be a string, got {type(typedef)} for {name}") + stub = f"{name}: TypeAlias = {cs_prefix}{typedef}" body.append(textwrap.indent(stub, prefix=indent)) if not body: diff --git a/dissect/cstruct/types/structure.py b/dissect/cstruct/types/structure.py index f2e8e03..b26d12b 100644 --- a/dissect/cstruct/types/structure.py +++ b/dissect/cstruct/types/structure.py @@ -306,7 +306,7 @@ def _write(cls, stream: BinaryIO, data: Structure) -> int: num = 0 for field in cls.__fields__: - field_type = cls.cs.resolve(field.type) + field_type = field.type bit_field_type = ( (field_type.type if isinstance(field_type, EnumMetaType) else field_type) if field.bits else None @@ -515,7 +515,7 @@ def _read_fields( buf = io.BytesIO(stream.read(cls.size)) for field in cls.__fields__: - field_type = cls.cs.resolve(field.type) + field_type = field.type start = 0 if field.offset is not None: diff --git a/tests/test_basic.py b/tests/test_basic.py index bb0dd47..3113123 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -37,7 +37,7 @@ def test_load_file(cs: cstruct, compiled: bool, tmp_path: Path) -> None: tmp_path.joinpath("testdef.txt").write_text(textwrap.dedent(cdef)) cs.loadfile(tmp_path.joinpath("testdef.txt"), compiled=compiled) - assert "test" in cs.typedefs + assert "test" in cs.types def test_load_init() -> None: @@ -49,12 +49,12 @@ def test_load_init() -> None: """ # load with first positional argument cs = cstruct(cdef) - assert "test" in cs.typedefs + assert "test" in cs.types assert cs.endian == "<" # load from keyword argument and big endian cs = cstruct(load=cdef, endian=">") - assert "test" in cs.typedefs + assert "test" in cs.types a = cs.test(a=0xBADC0DE, b=0xACCE55ED) assert len(bytes(a)) == 12 assert bytes(a) == a.dumps() @@ -62,7 +62,7 @@ def test_load_init() -> None: # load using positional argument and little endian cs = cstruct(cdef, endian="<") - assert "test" in cs.typedefs + assert "test" in cs.types a = cs.test(a=0xBADC0DE, b=0xACCE55ED) assert len(bytes(a)) == 12 assert bytes(a) == a.dumps() @@ -81,7 +81,7 @@ def test_load_init_kwargs_only() -> None: cs = cstruct(cdef, ">") cs = cstruct(cdef, endian=">") - assert "test" in cs.typedefs + assert "test" in cs.types assert cs.endian == ">" @@ -97,7 +97,7 @@ def test_type_resolve(cs: cstruct) -> None: cs.add_type("ref0", "uint32") for i in range(1, 15): # Recursion limit is currently 10 - cs.add_type(f"ref{i}", f"ref{i - 1}") + cs.add_typedef(f"ref{i}", f"ref{i - 1}") with pytest.raises(ResolveError, match="Recursion limit exceeded"): cs.resolve("ref14") @@ -455,7 +455,7 @@ def test_reserved_keyword(cs: cstruct, compiled: bool) -> None: cs.load(cdef, compiled=compiled) for name in ["in", "class", "for"]: - assert name in cs.typedefs + assert name in cs.types assert verify_compiled(cs.resolve(name), compiled) assert cs.resolve(name)(b"\x01").a == 1 diff --git a/tests/test_parser.py b/tests/test_parser.py index 85a7906..12c79cc 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -119,7 +119,7 @@ def test_structure_names(cs: cstruct) -> None: """ cs.load(cdef) - assert all(c in cs.typedefs for c in ("a", "b", "c", "d", "e")) + assert all(c in cs.typedefs | cs.types for c in ("a", "b", "c", "d", "e")) assert cs.a.__name__ == "a" assert cs.b.__name__ == "b" @@ -188,7 +188,7 @@ def test_conditional_ifdef(cs: cstruct) -> None: """ cs.load(cdef) - assert "test" in cs.typedefs + assert "test" in cs.types def test_conditional_ifndef(cs: cstruct) -> None: @@ -218,7 +218,7 @@ def test_conditional_ifndef_guard(cs: cstruct) -> None: cs.load(cdef) assert "__MYGUARD" in cs.consts - assert "myStruct" in cs.typedefs + assert "myStruct" in cs.types def test_conditional_nested() -> None: @@ -265,7 +265,7 @@ def test_conditional_in_struct(cs: cstruct) -> None: """ cs.load(cdef) - assert "t_bitfield" in cs.typedefs + assert "t_bitfield" in cs.types assert "fval" in cs.t_bitfield.fields assert "bit0" in cs.t_bitfield.fields["fval"].type.fields assert "bit1" in cs.t_bitfield.fields["fval"].type.fields diff --git a/tests/test_tools_stubgen.py b/tests/test_tools_stubgen.py index 52e4f82..39837ab 100644 --- a/tests/test_tools_stubgen.py +++ b/tests/test_tools_stubgen.py @@ -288,9 +288,6 @@ def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ... """, """ class cstruct(cstruct): - __fs16: TypeAlias = cstruct.uint16 - __fs32: TypeAlias = cstruct.uint32 - __fs64: TypeAlias = cstruct.uint64 class Test(Structure): a: cstruct.uint16 b: cstruct.uint32 @@ -300,6 +297,9 @@ def __init__(self, a: cstruct.uint16 | None = ..., b: cstruct.uint32 | None = .. @overload def __init__(self, fh: bytes | memoryview | bytearray | BinaryIO, /): ... + __fs16: TypeAlias = cstruct.__u16 + __fs32: TypeAlias = cstruct.__u32 + __fs64: TypeAlias = cstruct.__u64 """, # noqa: E501 id="typedef stub", ),