diff --git a/dissect/util/_build.py b/dissect/util/_build.py index fd0e347..fa3ed7e 100644 --- a/dissect/util/_build.py +++ b/dissect/util/_build.py @@ -1,4 +1,5 @@ # Reference: https://setuptools.pypa.io/en/latest/build_meta.html#dynamic-build-dependencies-and-other-build-meta-tweaks +# type: ignore from __future__ import annotations import os diff --git a/dissect/util/_native/hash/__init__.py b/dissect/util/_native/hash/__init__.pyi similarity index 100% rename from dissect/util/_native/hash/__init__.py rename to dissect/util/_native/hash/__init__.pyi diff --git a/dissect/util/_native/hash/crc32c.py b/dissect/util/_native/hash/crc32c.pyi similarity index 100% rename from dissect/util/_native/hash/crc32c.py rename to dissect/util/_native/hash/crc32c.pyi diff --git a/dissect/util/compression/__init__.py b/dissect/util/compression/__init__.py index 3661b55..7f7d0ac 100644 --- a/dissect/util/compression/__init__.py +++ b/dissect/util/compression/__init__.py @@ -23,8 +23,8 @@ try: from dissect.util import _native - lz4 = lz4_native = _native.compression.lz4 - lzo = lzo_native = _native.compression.lzo + lz4 = lz4_native = _native.compression.lz4 # type: ignore + lzo = lzo_native = _native.compression.lzo # type: ignore except (ImportError, AttributeError): lz4_native = lzo_native = None diff --git a/dissect/util/compression/lz4.py b/dissect/util/compression/lz4.py index 6736bd0..214b7a5 100644 --- a/dissect/util/compression/lz4.py +++ b/dissect/util/compression/lz4.py @@ -25,10 +25,10 @@ def _get_length(src: BinaryIO, length: int) -> int: def decompress( - src: bytes | BinaryIO, + src: bytes | bytearray | memoryview | BinaryIO, uncompressed_size: int = -1, return_bytearray: bool = False, -) -> bytes | tuple[bytes, int]: +) -> bytes | bytearray | tuple[bytes | bytearray, int]: """LZ4 decompress from a file-like object or bytes up to a certain length. Assumes no header. Args: @@ -39,7 +39,7 @@ def decompress( Returns: The decompressed data. 
""" - if not hasattr(src, "read"): + if isinstance(src, bytes | bytearray | memoryview): src = io.BytesIO(src) dst = bytearray() diff --git a/dissect/util/compression/lzbitmap.py b/dissect/util/compression/lzbitmap.py index 46c921a..139cef0 100644 --- a/dissect/util/compression/lzbitmap.py +++ b/dissect/util/compression/lzbitmap.py @@ -11,7 +11,7 @@ _H = struct.Struct(" bytes: +def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes: """LZBITMAP decompress from a file-like object or bytes. Decompresses until EOF or EOS of the input data. @@ -22,7 +22,7 @@ def decompress(src: bytes | BinaryIO) -> bytes: Returns: The decompressed data. """ - if not hasattr(src, "read"): + if isinstance(src, bytes | bytearray | memoryview): src = io.BytesIO(src) if src.read(4) != b"ZBM\x09": @@ -54,7 +54,7 @@ def decompress(src: bytes | BinaryIO) -> bytes: buf = memoryview(src.read(compressed_size)) # Build the bitmap/token map - token_map = [] + token_map: list[tuple[int | None, int]] = [] bits = int.from_bytes(buf[-17:], "little") for i in range(0xF): if i < 3: @@ -97,7 +97,7 @@ def decompress(src: bytes | BinaryIO) -> bytes: for _ in range(repeat): bitmap, token = token_map[idx] - if idx < 3: + if bitmap is None: # idx < 3, but this makes the type checker happy # Index 0, 1, 2 are special and indicate we need to read a bitmap from the bitmap region bitmap = buf[bitmap_offset] bitmap_offset += 1 diff --git a/dissect/util/compression/lzfse.py b/dissect/util/compression/lzfse.py index 0c6e945..32f9b4e 100644 --- a/dissect/util/compression/lzfse.py +++ b/dissect/util/compression/lzfse.py @@ -398,7 +398,7 @@ def _decode_lmd( return bytes(dst) -def decompress(src: bytes | BinaryIO) -> bytes: +def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes: """LZFSE decompress from a file-like object or bytes. Decompresses until EOF or EOS of the input data. @@ -409,7 +409,7 @@ def decompress(src: bytes | BinaryIO) -> bytes: Returns: The decompressed data. 
""" - if not hasattr(src, "read"): + if isinstance(src, bytes | bytearray | memoryview): src = io.BytesIO(src) dst = bytearray() diff --git a/dissect/util/compression/lznt1.py b/dissect/util/compression/lznt1.py index 5981b86..b4ccf7f 100644 --- a/dissect/util/compression/lznt1.py +++ b/dissect/util/compression/lznt1.py @@ -25,7 +25,7 @@ def _get_displacement(offset: int) -> int: TAG_MASKS = [(1 << i) for i in range(8)] -def decompress(src: bytes | BinaryIO) -> bytes: +def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes: """LZNT1 decompress from a file-like object or bytes. Args: @@ -34,7 +34,7 @@ def decompress(src: bytes | BinaryIO) -> bytes: Returns: The decompressed data. """ - if not hasattr(src, "read"): + if isinstance(src, bytes | bytearray | memoryview): src = io.BytesIO(src) offset = src.tell() diff --git a/dissect/util/compression/lzo.py b/dissect/util/compression/lzo.py index 8158225..696abb8 100644 --- a/dissect/util/compression/lzo.py +++ b/dissect/util/compression/lzo.py @@ -23,7 +23,7 @@ def _read_length(src: BinaryIO, val: int, mask: int) -> int: return length + mask + val -def decompress(src: bytes | BinaryIO, header: bool = True, buflen: int = -1) -> bytes: +def decompress(src: bytes | bytearray | memoryview | BinaryIO, header: bool = True, buflen: int = -1) -> bytes: """LZO decompress from a file-like object or bytes. Assumes no header. Arguments are largely compatible with python-lzo API. @@ -36,7 +36,7 @@ def decompress(src: bytes | BinaryIO, header: bool = True, buflen: int = -1) -> Returns: The decompressed data. 
""" - if not hasattr(src, "read"): + if isinstance(src, bytes | bytearray | memoryview): src = io.BytesIO(src) dst = bytearray() diff --git a/dissect/util/compression/lzvn.py b/dissect/util/compression/lzvn.py index 9819372..1dd7ebc 100644 --- a/dissect/util/compression/lzvn.py +++ b/dissect/util/compression/lzvn.py @@ -56,7 +56,7 @@ _H = struct.Struct(" bytes: +def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes: """LZVN decompress from a file-like object or bytes. Decompresses until EOF or EOS of the input data. @@ -67,7 +67,7 @@ def decompress(src: bytes | BinaryIO) -> bytes: Returns: The decompressed data. """ - if not hasattr(src, "read"): + if isinstance(src, bytes | bytearray | memoryview): src = io.BytesIO(src) offset = src.tell() @@ -207,7 +207,7 @@ def decompress(src: bytes | BinaryIO) -> bytes: if src_size < opc_len: break - src_size -= opc_len + L + src_size -= opc_len break elif opc in OP_UDEF: diff --git a/dissect/util/compression/lzxpress.py b/dissect/util/compression/lzxpress.py index 74996e9..2d19496 100644 --- a/dissect/util/compression/lzxpress.py +++ b/dissect/util/compression/lzxpress.py @@ -6,7 +6,7 @@ from typing import BinaryIO -def decompress(src: bytes | BinaryIO) -> bytes: +def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes: """LZXPRESS decompress from a file-like object or bytes. Args: @@ -15,7 +15,7 @@ def decompress(src: bytes | BinaryIO) -> bytes: Returns: The decompressed data. 
""" - if not hasattr(src, "read"): + if isinstance(src, bytes | bytearray | memoryview): src = io.BytesIO(src) offset = src.tell() diff --git a/dissect/util/compression/lzxpress_huffman.py b/dissect/util/compression/lzxpress_huffman.py index 9c73b8b..f709fab 100644 --- a/dissect/util/compression/lzxpress_huffman.py +++ b/dissect/util/compression/lzxpress_huffman.py @@ -19,10 +19,10 @@ def _read_16_bit(fh: BinaryIO) -> int: class Node: __slots__ = ("children", "is_leaf", "symbol") - def __init__(self, symbol: Symbol | None = None, is_leaf: bool = False): + def __init__(self, symbol: int = 0, is_leaf: bool = False): self.symbol = symbol self.is_leaf = is_leaf - self.children = [None, None] + self.children: dict[int, Node] = {} def _add_leaf(nodes: list[Node], idx: int, mask: int, bits: int) -> int: @@ -32,7 +32,7 @@ def _add_leaf(nodes: list[Node], idx: int, mask: int, bits: int) -> int: while bits > 1: bits -= 1 childidx = (mask >> bits) & 1 - if node.children[childidx] is None: + if childidx not in node.children: node.children[childidx] = nodes[i] nodes[i].is_leaf = False i += 1 @@ -84,24 +84,28 @@ def _build_tree(buf: bytes) -> Node: class BitString: - def __init__(self): - self.source = None + def __init__(self, fh: BinaryIO): + self.fh = fh self.mask = 0 self.bits = 0 @property def index(self) -> int: - return self.source.tell() + return self.fh.tell() - def init(self, fh: BinaryIO) -> None: - self.mask = (_read_16_bit(fh) << 16) + _read_16_bit(fh) + def reset(self) -> None: + self.mask = (_read_16_bit(self.fh) << 16) + _read_16_bit(self.fh) self.bits = 32 - self.source = fh def read(self, n: int) -> bytes: - return self.source.read(n) + return self.fh.read(n) - def lookup(self, n: int) -> int: + def take(self, n: int) -> int: + value = self.peek(n) + self.skip(n) + return value + + def peek(self, n: int) -> int: if n == 0: return 0 @@ -111,19 +115,19 @@ def skip(self, n: int) -> None: self.mask = (self.mask << n) & 0xFFFFFFFF self.bits -= n if self.bits < 16: - 
self.mask += _read_16_bit(self.source) << (16 - self.bits) + self.mask += _read_16_bit(self.fh) << (16 - self.bits) self.bits += 16 - def decode(self, root: Node) -> Symbol: + def decode(self, root: Node) -> int: node = root while not node.is_leaf: - bit = self.lookup(1) - self.skip(1) + bit = self.take(1) node = node.children[bit] + return node.symbol -def decompress(src: bytes | BinaryIO) -> bytes: +def decompress(src: bytes | bytearray | memoryview | BinaryIO) -> bytes: """LZXPRESS decompress from a file-like object or bytes. Decompresses until EOF of the input data. @@ -134,7 +138,7 @@ def decompress(src: bytes | BinaryIO) -> bytes: Returns: The decompressed data. """ - if not hasattr(src, "read"): + if isinstance(src, bytes | bytearray | memoryview): src = io.BytesIO(src) dst = bytearray() @@ -144,11 +148,11 @@ def decompress(src: bytes | BinaryIO) -> bytes: size = src.tell() - start_offset src.seek(start_offset, io.SEEK_SET) - bitstring = BitString() + bitstring = BitString(src) while src.tell() - start_offset < size: root = _build_tree(src.read(256)) - bitstring.init(src) + bitstring.reset() chunk_size = 0 while chunk_size < 65536 and src.tell() - start_offset < size: @@ -161,13 +165,13 @@ def decompress(src: bytes | BinaryIO) -> bytes: length = symbol & 0x0F symbol >>= 4 - offset = (1 << symbol) + bitstring.lookup(symbol) + offset = (1 << symbol) + bitstring.peek(symbol) if length == 15: length = ord(bitstring.read(1)) + 15 if length == 270: - length = _read_16_bit(bitstring.source) + length = _read_16_bit(bitstring.fh) bitstring.skip(symbol) diff --git a/dissect/util/compression/sevenbit.py b/dissect/util/compression/sevenbit.py index 58e6fba..f72b251 100644 --- a/dissect/util/compression/sevenbit.py +++ b/dissect/util/compression/sevenbit.py @@ -1,10 +1,10 @@ from __future__ import annotations -from io import BytesIO +import io from typing import BinaryIO -def compress(src: bytes | BinaryIO) -> bytes: +def compress(src: bytes | bytearray | memoryview | 
BinaryIO) -> bytes: """Sevenbit compress from a file-like object or bytes. Args: @@ -13,8 +13,8 @@ def compress(src: bytes | BinaryIO) -> bytes: Returns: The compressed data. """ - if not hasattr(src, "read"): - src = BytesIO(src) + if isinstance(src, bytes | bytearray | memoryview): + src = io.BytesIO(src) dst = bytearray() @@ -39,7 +39,7 @@ def compress(src: bytes | BinaryIO) -> bytes: return bytes(dst) -def decompress(src: bytes | BinaryIO, wide: bool = False) -> bytes: +def decompress(src: bytes | bytearray | memoryview | BinaryIO, wide: bool = False) -> bytes: """Sevenbit decompress from a file-like object or bytes. Args: @@ -48,8 +48,8 @@ def decompress(src: bytes | BinaryIO, wide: bool = False) -> bytes: Returns: The decompressed data. """ - if not hasattr(src, "read"): - src = BytesIO(src) + if isinstance(src, bytes | bytearray | memoryview): + src = io.BytesIO(src) dst = bytearray() diff --git a/dissect/util/compression/xz.py b/dissect/util/compression/xz.py index 826288b..ec0e746 100644 --- a/dissect/util/compression/xz.py +++ b/dissect/util/compression/xz.py @@ -8,7 +8,7 @@ CRC_SIZE = 4 -def repair_checksum(fh: BinaryIO) -> BinaryIO: +def repair_checksum(fh: BinaryIO) -> OverlayStream: """Repair CRC32 checksums for all headers in an XZ stream. FortiOS XZ files have (on purpose) corrupt streams which they read using a modified ``xz`` binary. 
@@ -55,7 +55,7 @@ def repair_checksum(fh: BinaryIO) -> BinaryIO: # Parse the index isize, num_records = _mbi(index[1:]) index = index[1 + isize : -4] - records = [] + records: list[tuple[int, int]] = [] for _ in range(num_records): if not index: raise ValueError("Missing index size") diff --git a/dissect/util/cpio.py b/dissect/util/cpio.py index 404878b..03ae15c 100644 --- a/dissect/util/cpio.py +++ b/dissect/util/cpio.py @@ -3,8 +3,8 @@ import stat import struct import tarfile -from tarfile import InvalidHeaderError -from typing import BinaryIO +from tarfile import EmptyHeaderError, InvalidHeaderError # type: ignore +from typing import Any, BinaryIO, cast FORMAT_CPIO_BIN = 10 FORMAT_CPIO_ODC = 11 @@ -39,8 +39,20 @@ class CpioInfo(tarfile.TarInfo): """ + format: int + _mode: int + magic: int + ino: int + nlink: int + rdevmajor: int + rdevminor: int + namesize: int + @classmethod def fromtarfile(cls, tarfile: tarfile.TarFile) -> CpioInfo: + if not tarfile.fileobj: + raise RuntimeError("Invalid tarfile state") + if tarfile.format not in ( FORMAT_CPIO_BIN, FORMAT_CPIO_ODC, @@ -49,7 +61,7 @@ def fromtarfile(cls, tarfile: tarfile.TarFile) -> CpioInfo: FORMAT_CPIO_HPBIN, FORMAT_CPIO_HPODC, ): - tarfile.format = detect_header(tarfile.fileobj) + tarfile.format = detect_header(cast("BinaryIO", tarfile.fileobj)) # type: ignore if tarfile.format in (FORMAT_CPIO_BIN, FORMAT_CPIO_HPBIN): buf = tarfile.fileobj.read(26) @@ -58,29 +70,33 @@ def fromtarfile(cls, tarfile: tarfile.TarFile) -> CpioInfo: elif tarfile.format in (FORMAT_CPIO_NEWC, FORMAT_CPIO_CRC): buf = tarfile.fileobj.read(110) else: - raise InvalidHeaderError("Unknown cpio type") + raise InvalidHeaderError("Unknown cpio type") # type: ignore - obj = cls.frombuf(buf, tarfile.format, tarfile.encoding, tarfile.errors) - obj.format = tarfile.format + obj = cls.frombuf(buf, tarfile.encoding, tarfile.errors, format=tarfile.format) # type: ignore + obj.format = cast("int", tarfile.format) obj.offset = tarfile.fileobj.tell() 
- len(buf) return obj._proc_member(tarfile) @classmethod - def frombuf(cls, buf: bytes, format: int, encoding: str, errors: str) -> CpioInfo: + def frombuf( + cls, buf: bytes | bytearray, encoding: str | None, errors: str, format: int = FORMAT_CPIO_UNKNOWN + ) -> CpioInfo: if format in (FORMAT_CPIO_BIN, FORMAT_CPIO_ODC, FORMAT_CPIO_HPBIN, FORMAT_CPIO_HPODC): obj = cls._old_frombuf(buf, format) elif format in (FORMAT_CPIO_NEWC, FORMAT_CPIO_CRC): obj = cls._new_frombuf(buf, format) + else: + raise InvalidHeaderError("Unknown cpio type") # Common postprocessing ftype = stat.S_IFMT(obj._mode) - obj.type = TYPE_MAP.get(ftype, ftype) + obj.type = TYPE_MAP.get(ftype, tarfile.REGTYPE) obj.mode = stat.S_IMODE(obj._mode) return obj @classmethod - def _old_frombuf(cls, buf: bytes, format: int) -> CpioInfo: + def _old_frombuf(cls, buf: bytes | bytearray, format: int) -> CpioInfo: if format in (FORMAT_CPIO_BIN, FORMAT_CPIO_HPBIN): values = list(struct.unpack("<13H", buf)) if values[0] == _swap16(CPIO_MAGIC_OLD): @@ -94,7 +110,7 @@ def _old_frombuf(cls, buf: bytes, format: int) -> CpioInfo: values = [int(v, 8) for v in struct.unpack("<6s6s6s6s6s6s6s6s11s6s11s", buf)] if values[0] != CPIO_MAGIC_OLD: - raise InvalidHeaderError(f"Invalid (old) ASCII/binary cpio header magic: {oct(values[0])}") + raise tarfile.InvalidHeaderError(f"Invalid (old) ASCII/binary cpio header magic: {oct(values[0])}") # type: ignore obj = cls() obj.devmajor = values[1] >> 8 @@ -133,11 +149,11 @@ def _old_frombuf(cls, buf: bytes, format: int) -> CpioInfo: return obj @classmethod - def _new_frombuf(cls, buf: bytes, format: int) -> CpioInfo: + def _new_frombuf(cls, buf: bytes | bytearray, format: int) -> CpioInfo: values = struct.unpack("<6s8s8s8s8s8s8s8s8s8s8s8s8s8s", buf) values = [int(values[0], 8)] + [int(v, 16) for v in values[1:]] if values[0] not in (CPIO_MAGIC_NEW, CPIO_MAGIC_CRC): - raise InvalidHeaderError(f"Invalid (new) ASCII cpio header magic: {oct(values[0])}") + raise 
InvalidHeaderError(f"Invalid (new) ASCII cpio header magic: {oct(values[0])}") # type: ignore obj = cls() obj._mode = values[2] @@ -159,11 +175,14 @@ def _new_frombuf(cls, buf: bytes, format: int) -> CpioInfo: return obj - def _proc_member(self, tarfile: tarfile.TarFile) -> CpioInfo | None: - self.name = tarfile.fileobj.read(self.namesize - 1).decode(tarfile.encoding, tarfile.errors) + def _proc_member(self, tarfile: tarfile.TarFile) -> CpioInfo: + if not tarfile.fileobj: + raise RuntimeError("Invalid tarfile state") + + self.name = tarfile.fileobj.read(self.namesize - 1).decode(tarfile.encoding or "utf-8", tarfile.errors) if self.name == "TRAILER!!!": # The last entry in a cpio file has the special name ``TRAILER!!!``, indicating the end of the archive - return None + raise EmptyHeaderError("End of cpio archive") # type: ignore offset = tarfile.fileobj.tell() + 1 self.offset_data = self._round_word(offset) @@ -171,7 +190,7 @@ def _proc_member(self, tarfile: tarfile.TarFile) -> CpioInfo | None: if self.issym(): tarfile.fileobj.seek(self.offset_data) - self.linkname = tarfile.fileobj.read(self.size).decode(tarfile.encoding, tarfile.errors) + self.linkname = tarfile.fileobj.read(self.size).decode(tarfile.encoding or "utf-8", tarfile.errors) self.size = 0 return self @@ -187,7 +206,7 @@ def _round_word(self, offset: int) -> int: def issocket(self) -> bool: """Return True if it is a socket.""" - return self.type == stat.S_IFSOCK + return self._mode == stat.S_IFSOCK def detect_header(fh: BinaryIO) -> int: @@ -216,11 +235,13 @@ def _swap16(value: int) -> int: def CpioFile(*args, **kwargs) -> tarfile.TarFile: # noqa: N802 """Utility wrapper around ``tarfile.TarFile`` to easily open cpio archives.""" + kwargs["tarinfo"] = CpioInfo kwargs.setdefault("format", FORMAT_CPIO_UNKNOWN) - return tarfile.TarFile(*args, **kwargs, tarinfo=CpioInfo) + return tarfile.TarFile(*args, **kwargs) -def open(*args, **kwargs) -> tarfile.TarFile: +def open(*args: Any, **kwargs: Any) -> 
tarfile.TarFile: """Utility wrapper around ``tarfile.open`` to easily open cpio archives.""" + kwargs["tarinfo"] = CpioInfo kwargs.setdefault("format", FORMAT_CPIO_UNKNOWN) - return tarfile.open(*args, **kwargs, tarinfo=CpioInfo) + return tarfile.open(*args, **kwargs) diff --git a/dissect/util/encoding/surrogateescape.py b/dissect/util/encoding/surrogateescape.py index 7240cfe..3ac0b4d 100644 --- a/dissect/util/encoding/surrogateescape.py +++ b/dissect/util/encoding/surrogateescape.py @@ -5,7 +5,7 @@ def error_handler(error: Exception) -> tuple[str, int]: if not isinstance(error, UnicodeDecodeError): raise error - result = [] + result: list[str] = [] for i in range(error.start, error.end): byte = error.object[i] if byte < 128: diff --git a/dissect/util/hash/__init__.py b/dissect/util/hash/__init__.py index c2fb5f1..327f58e 100644 --- a/dissect/util/hash/__init__.py +++ b/dissect/util/hash/__init__.py @@ -15,7 +15,7 @@ try: from dissect.util import _native - crc32c = crc32c_native = _native.hash.crc32c + crc32c = crc32c_native = _native.hash.crc32c # type: ignore except (ImportError, AttributeError): crc32c_native = None diff --git a/dissect/util/hash/jenkins.py b/dissect/util/hash/jenkins.py index 8f49a83..6be2c4d 100644 --- a/dissect/util/hash/jenkins.py +++ b/dissect/util/hash/jenkins.py @@ -1,7 +1,7 @@ from struct import unpack -def _mix64(a: int, b: int, c: int) -> int: +def _mix64(a: int, b: int, c: int) -> tuple[int, int, int]: """Mixes three 64-bit values reversibly.""" # Implement logical right shift by masking first a = (a - b - c) ^ ((c & 0xFFFFFFFFFFFFFFFF) >> 43) diff --git a/dissect/util/ldap.py b/dissect/util/ldap.py index 145cf79..baf0a3a 100644 --- a/dissect/util/ldap.py +++ b/dissect/util/ldap.py @@ -48,7 +48,7 @@ def __init__(self, query: str) -> None: self.query: str = query self.children: list[SearchFilter] = [] - self.operator: LogicalOperator | ComparisonOperator | None = None + self.operator: LogicalOperator | ComparisonOperator = None 
self.attribute: str | None = None self.value: str | None = None self._extended_rule: str | None = None diff --git a/dissect/util/plist.py b/dissect/util/plist.py index baa4f78..8ee7d39 100644 --- a/dissect/util/plist.py +++ b/dissect/util/plist.py @@ -3,40 +3,130 @@ import plistlib import uuid from collections import UserDict -from typing import TYPE_CHECKING, Any, BinaryIO +from datetime import datetime +from typing import TYPE_CHECKING, Any, BinaryIO, TypeAlias, TypedDict, Union, cast from dissect.util.ts import cocoatimestamp if TYPE_CHECKING: - from datetime import datetime + from collections.abc import Callable, Iterable + + +_Value: TypeAlias = Union[ + bool, + bytes, + int, + float, + str, + datetime, + uuid.UUID, + "NSDictionary", + "NSObject", + "_NSClass", + list["_Value"], + dict[str, "_Value"], + None, +] + + +_NSKeyedArchiver = TypedDict( + "_NSKeyedArchiver", + { + "$version": int, + "$archiver": str, + "$top": dict[str, plistlib.UID], + "$objects": list["_Value | _NSObject"], + }, +) +_NSClass = TypedDict( + "_NSClass", + { + "$classname": str, + "$classes": list[str], + }, +) +_NSObject = TypedDict( + "_NSObject", + { + "$class": plistlib.UID, + }, +) +_NSArray = TypedDict( + "_NSArray", + { + "$class": plistlib.UID, + "NS.objects": list[_Value | plistlib.UID], + }, +) +_NSMutableArray = _NSMutableSet = _NSSet = _NSArray +_NSDictionary = TypedDict( + "_NSDictionary", + { + "$class": plistlib.UID, + "NS.keys": list[plistlib.UID], + "NS.objects": list[plistlib.UID], + }, +) +_NSMutableDictionary = _NSDictionary +_NSData = TypedDict( + "_NSData", + { + "$class": plistlib.UID, + "NS.data": bytes, + }, +) +_NSMutableData = _NSData +_NSDate = TypedDict( + "_NSDate", + { + "$class": plistlib.UID, + "NS.time": int, + }, +) +_NSUUID = TypedDict( + "_NSUUID", + { + "$class": plistlib.UID, + "NS.uuidbytes": bytes, + }, +) +_NSURL = TypedDict( + "_NSURL", + { + "$class": plistlib.UID, + "NS.base": plistlib.UID, + "NS.relative": plistlib.UID, + }, +) class 
NSKeyedArchiver: def __init__(self, fh: BinaryIO): - self.plist = plistlib.load(fh) + plist: Any = plistlib.load(fh) - if not isinstance(self.plist, dict) or not all( - key in self.plist for key in ["$version", "$archiver", "$top", "$objects"] + if not isinstance(plist, dict) or not all( + key in plist for key in ["$version", "$archiver", "$top", "$objects"] ): raise ValueError("File is not an NSKeyedArchiver plist") - self._objects = self.plist.get("$objects") - self._cache = {} + self.plist: _NSKeyedArchiver = cast("_NSKeyedArchiver", plist) + self._objects = self.plist.get("$objects", []) + self._cache: dict[int, _Value] = {} - self.top = {} + self.top: dict[str, _Value] = {} for name, value in self.plist.get("$top", {}).items(): self.top[name] = self._parse(value) - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key: str) -> _Value: return self.top[key] def __repr__(self) -> str: return f"" - def get(self, key: str, default: Any | None = None) -> Any: + def get(self, key: str, default: _Value | None = None) -> _Value: return self.top.get(key, default) - def _parse(self, uid: Any) -> Any: + def _parse(self, uid: _Value | plistlib.UID) -> _Value: if not isinstance(uid, plistlib.UID): return uid @@ -47,13 +137,12 @@ def _parse(self, uid: Any) -> Any: self._cache[num] = result return result - def _parse_obj(self, obj: Any) -> Any: + def _parse_obj(self, obj: _Value | _NSObject) -> _Value: if isinstance(obj, dict): - klass = obj.get("$class") - if klass: - klass_name = self._parse(klass).get("$classname") + if klass := obj.get("$class"): + klass_name = cast("_NSClass", self._parse(klass)).get("$classname") return CLASSES.get(klass_name, NSObject)(self, obj) - return obj + return cast("dict[Any, Any]", obj) if isinstance(obj, list): return list(map(self._parse, obj)) @@ -68,79 +157,80 @@ def _parse_obj(self, obj: Any) -> Any: class NSObject: - def __init__(self, nskeyed: NSKeyedArchiver, obj: dict[str, Any]): + def __init__(self, nskeyed: 
NSKeyedArchiver, obj: _NSObject): self.nskeyed = nskeyed self.obj = obj - self._class = nskeyed._parse(obj.get("$class", {})) + self._class: _NSClass = cast("_NSClass", nskeyed._parse(obj.get("$class", {}))) self._classname = self._class.get("$classname", "Unknown") self._classes = self._class.get("$classes", []) - def __getitem__(self, attr: str) -> Any: - obj = self.obj[attr] - return self.nskeyed._parse(obj) + def __getitem__(self, key: str) -> _Value: + return self.nskeyed._parse(cast("_Value", self.obj[key])) - def __getattr__(self, attr: str) -> Any: + def __getattr__(self, attr: str) -> _Value: try: return self[attr] except KeyError: raise AttributeError(attr) - def __repr__(self): + def __repr__(self) -> str: return f"<{self._classname}>" - def keys(self) -> list[str]: + def keys(self) -> Iterable[str]: return self.obj.keys() - def get(self, attr: str, default: Any | None = None) -> Any: + def get(self, attr: str, default: object | None = None) -> object | None: try: return self[attr] except KeyError: return default -class NSDictionary(UserDict, NSObject): - def __init__(self, nskeyed: NSKeyedArchiver, obj: dict[str, Any]): +class NSDictionary(UserDict, NSObject): # type: ignore + def __init__(self, nskeyed: NSKeyedArchiver, obj: _NSDictionary): NSObject.__init__(self, nskeyed, obj) - self.data = {nskeyed._parse(key): obj for key, obj in zip(obj["NS.keys"], obj["NS.objects"], strict=False)} + self.data: dict[_Value, Any] = { + nskeyed._parse(key): obj for key, obj in zip(obj["NS.keys"], obj["NS.objects"], strict=False) + } def __repr__(self) -> str: return NSObject.__repr__(self) - def __getitem__(self, key: str) -> Any: + def __getitem__(self, key: str) -> _Value: return self.nskeyed._parse(self.data[key]) -def parse_nsarray(nskeyed: NSKeyedArchiver, obj: dict[str, Any]) -> list[Any]: +def parse_nsarray(nskeyed: NSKeyedArchiver, obj: _NSArray | _NSMutableArray) -> list[_Value]: return list(map(nskeyed._parse, obj["NS.objects"])) -def parse_nsset(nskeyed: 
NSKeyedArchiver, obj: dict[str, Any]) -> list[Any]: +def parse_nsset(nskeyed: NSKeyedArchiver, obj: _NSSet | _NSMutableSet) -> list[_Value]: # Some values are not hashable, so return as list return parse_nsarray(nskeyed, obj) -def parse_nsdata(nskeyed: NSKeyedArchiver, obj: dict[str, Any]) -> Any: +def parse_nsdata(nskeyed: NSKeyedArchiver, obj: _NSData) -> bytes: return obj["NS.data"] -def parse_nsdate(nskeyed: NSKeyedArchiver, obj: dict[str, Any]) -> datetime: +def parse_nsdate(nskeyed: NSKeyedArchiver, obj: _NSDate) -> datetime: return cocoatimestamp(obj["NS.time"]) -def parse_nsuuid(nskeyed: NSKeyedArchiver, obj: dict[str, Any]) -> uuid.UUID: +def parse_nsuuid(nskeyed: NSKeyedArchiver, obj: _NSUUID) -> uuid.UUID: return uuid.UUID(bytes=obj["NS.uuidbytes"]) -def parse_nsurl(nskeyed: NSKeyedArchiver, obj: dict[str, Any]) -> str: - base = nskeyed._parse(obj["NS.base"]) - relative = nskeyed._parse(obj["NS.relative"]) +def parse_nsurl(nskeyed: NSKeyedArchiver, obj: _NSURL) -> str: + base = cast("str", nskeyed._parse(obj["NS.base"])) + relative = cast("str", nskeyed._parse(obj["NS.relative"])) if base: return f"{base}/{relative}" return relative -CLASSES = { +CLASSES: dict[str, Callable[[NSKeyedArchiver, Any], _Value]] = { "NSArray": parse_nsarray, "NSMutableArray": parse_nsarray, "NSDictionary": NSDictionary, diff --git a/dissect/util/stream.py b/dissect/util/stream.py index 7dc3b84..0d5f223 100644 --- a/dissect/util/stream.py +++ b/dissect/util/stream.py @@ -6,7 +6,10 @@ import zlib from bisect import bisect_left, bisect_right from threading import Lock -from typing import BinaryIO +from typing import TYPE_CHECKING, BinaryIO + +if TYPE_CHECKING: + from _typeshed import WriteableBuffer STREAM_BUFFER_SIZE = int(os.getenv("DISSECT_STREAM_BUFFER_SIZE", io.DEFAULT_BUFFER_SIZE)) @@ -43,7 +46,7 @@ def __init__(self, size: int | None = None, align: int = STREAM_BUFFER_SIZE): self._pos = 0 self._pos_align = 0 - self._buf = None + self._buf = b"" self._lock = Lock() def 
readable(self) -> bool: @@ -88,7 +91,7 @@ def _set_pos(self, pos: int) -> None: if self._pos_align != new_pos_align: self._pos_align = new_pos_align - self._buf = None + self._buf = b"" self._pos = pos @@ -166,16 +169,19 @@ def read(self, n: int = -1) -> bytes: # Misaligned remaining bytes if n > 0: self._fill_buf() + r.append(self._buf[:n]) self._set_pos(self._pos + n) return b"".join(r) - def readinto(self, b: bytearray) -> int: + def readinto(self, b: WriteableBuffer) -> int: """Read bytes into a pre-allocated bytes-like object b. Returns an int representing the number of bytes read (0 for EOF). """ + b = memoryview(b) + buf = self.read(len(b)) length = len(buf) b[:length] = buf @@ -294,6 +300,8 @@ class MappingStream(AlignedStream): align: The alignment size. """ + size: int + def __init__(self, size: int | None = None, align: int = STREAM_BUFFER_SIZE): super().__init__(size, align) self._runs: list[tuple[int, int, BinaryIO, int]] = [] @@ -311,10 +319,10 @@ def add(self, offset: int, size: int, fh: BinaryIO, file_offset: int = 0) -> Non """ self._runs.append((offset, size, fh, file_offset)) self._runs = sorted(self._runs, key=lambda run: run[0]) - self._buf = None + self._buf = b"" self.size = self._runs[-1][0] + self._runs[-1][1] - def _get_run_idx(self, offset: int) -> tuple[int, int, BinaryIO, int]: + def _get_run_idx(self, offset: int) -> int: """Find a mapping run for a given offset. Args: @@ -333,7 +341,7 @@ def _get_run_idx(self, offset: int) -> tuple[int, int, BinaryIO, int]: raise EOFError(f"No mapping for offset {offset}") def _read(self, offset: int, length: int) -> bytes: - result = [] + result: list[bytes] = [] run_idx = self._get_run_idx(offset) runlist_len = len(self._runs) @@ -384,8 +392,10 @@ class RunlistStream(AlignedStream): align: Optional alignment that differs from the block size, otherwise ``block_size`` is used as alignment. 
""" + size: int + def __init__( - self, fh: BinaryIO, runlist: list[tuple[int, int]], size: int, block_size: int, align: int | None = None + self, fh: BinaryIO, runlist: list[tuple[int | None, int]], size: int, block_size: int, align: int | None = None ): super().__init__(size, align or block_size) @@ -401,13 +411,13 @@ def __init__( self.block_size = block_size @property - def runlist(self) -> list[tuple[int, int]]: + def runlist(self) -> list[tuple[int | None, int]]: return self._runlist @runlist.setter - def runlist(self, runlist: list[tuple[int, int]]) -> None: + def runlist(self, runlist: list[tuple[int | None, int]]) -> None: self._runlist = runlist - self._runlist_offsets = [] + self._runlist_offsets: list[int] = [] offset = 0 # Create a list of starting offsets for each run so we can bisect that quickly when reading @@ -416,10 +426,10 @@ def runlist(self, runlist: list[tuple[int, int]]) -> None: self._runlist_offsets.append(offset) offset += block_count - self._buf = None + self._buf = b"" def _read(self, offset: int, length: int) -> bytes: - r = [] + r: list[bytes] = [] block_offset = offset // self.block_size @@ -479,7 +489,9 @@ def __init__(self, fh: BinaryIO, size: int | None = None, align: int = STREAM_BU self.overlays: dict[int, tuple[int, BinaryIO]] = {} self._lookup: list[int] = [] - def add(self, offset: int, data: bytes | BinaryIO, size: int | None = None) -> None: + def add( + self, offset: int, data: bytes | bytearray | memoryview | BinaryIO, size: int | None = None + ) -> OverlayStream: """Add an overlay at the given offset. Args: @@ -487,14 +499,14 @@ def add(self, offset: int, data: bytes | BinaryIO, size: int | None = None) -> N data: The bytes or file-like object to overlay. size: Optional size specification of the overlay, if it can't be inferred. 
""" - if not hasattr(data, "read"): + if isinstance(data, bytes | bytearray | memoryview): size = size or len(data) data = io.BytesIO(data) elif size is None: - size = data.size if hasattr(data, "size") else data.seek(0, io.SEEK_END) + size = getattr(data, "size", None) or data.seek(0, io.SEEK_END) if not size: - return None + return self if size < 0: raise ValueError("Size must be positive") @@ -510,12 +522,12 @@ def add(self, offset: int, data: bytes | BinaryIO, size: int | None = None) -> N # Clear the buffer if we add an overlay at our current position if self._buf and (self._pos_align <= offset + size and offset <= self._pos_align + len(self._buf)): - self._buf = None + self._buf = b"" return self def _read(self, offset: int, length: int) -> bytes: - result = [] + result: list[bytes] = [] fh = self._fh overlays = self.overlays @@ -590,21 +602,30 @@ class ZlibStream(AlignedStream): size: The size the stream should be. """ - def __init__(self, fh: BinaryIO, size: int | None = None, align: int = STREAM_BUFFER_SIZE, **kwargs): + def __init__( + self, + fh: BinaryIO, + size: int | None = None, + align: int = STREAM_BUFFER_SIZE, + *, + wbits: int = 15, + zdict: bytes = b"", + ): self._fh = fh - self._zlib = None - self._zlib_args = kwargs + self._fh.seek(0) + self._zlib_wbits = wbits + self._zlib_zdict = zdict + self._zlib = zlib.decompressobj(wbits=self._zlib_wbits, zdict=self._zlib_zdict) self._zlib_offset = 0 self._zlib_prepend = b"" self._zlib_prepend_offset = None - self._rewind() super().__init__(size, align) def _rewind(self) -> None: self._fh.seek(0) - self._zlib = zlib.decompressobj(**self._zlib_args) + self._zlib = zlib.decompressobj(wbits=self._zlib_wbits, zdict=self._zlib_zdict) self._zlib_offset = 0 self._zlib_prepend = b"" self._zlib_prepend_offset = None @@ -635,7 +656,7 @@ def _read_zlib(self, length: int) -> bytes: if length < 0: return self.readall() - result = [] + result: list[bytes] = [] while length > 0: buf = 
self._read_fh(io.DEFAULT_BUFFER_SIZE) decompressed = self._zlib.decompress(buf, length) @@ -661,7 +682,7 @@ def _read(self, offset: int, length: int) -> bytes: def readall(self) -> bytes: self._seek_zlib(self.tell()) - chunks = [] + chunks: list[bytes] = [] # sys.maxsize means the max length of output buffer is unlimited, # so that the whole input buffer can be decompressed within one # .decompress() call. diff --git a/dissect/util/tools/dump_nskeyedarchiver.py b/dissect/util/tools/dump_nskeyedarchiver.py index 7d12de5..2d362bb 100644 --- a/dissect/util/tools/dump_nskeyedarchiver.py +++ b/dissect/util/tools/dump_nskeyedarchiver.py @@ -1,7 +1,7 @@ from __future__ import annotations import argparse -from typing import Any +from typing import Any, cast from dissect.util.plist import NSKeyedArchiver, NSObject @@ -15,13 +15,13 @@ def main() -> None: try: obj = NSKeyedArchiver(fh) except ValueError as e: - parser.exit(str(e)) + parser.exit(1, str(e)) print(obj) print_object(obj.top) -def print_object(obj: Any, indent: int = 0, seen: set | None = None) -> None: +def print_object(obj: Any, indent: int = 0, seen: set[Any] | None = None) -> None: if seen is None: seen = set() @@ -33,7 +33,7 @@ def print_object(obj: Any, indent: int = 0, seen: set | None = None) -> None: pass if isinstance(obj, list): - for i, v in enumerate(obj): + for i, v in enumerate(cast("list[Any]", obj)): print(fmt(f"[{i}]:", indent)) print_object(v, indent + 1, seen) @@ -45,7 +45,7 @@ def print_object(obj: Any, indent: int = 0, seen: set | None = None) -> None: except TypeError: pass - for k in sorted(obj.keys()): + for k in sorted(cast("dict[Any, Any]", obj).keys()): print(fmt(f"{k}:", indent + 1)) print_object(obj[k], indent + 2, seen) diff --git a/dissect/util/xmemoryview.py b/dissect/util/xmemoryview.py index 773f00e..8a8e600 100644 --- a/dissect/util/xmemoryview.py +++ b/dissect/util/xmemoryview.py @@ -2,25 +2,41 @@ import struct import sys -from typing import TYPE_CHECKING, Any +from typing 
import TYPE_CHECKING, Any, Literal, TypeAlias if TYPE_CHECKING: from collections.abc import Iterator - -def xmemoryview(view: bytes, format: str) -> memoryview | _xmemoryview: +# fmt: off +_Formats: TypeAlias = Literal[ + "@h", "=h", "h", "!h", + "@H", "=H", "H", "!H", + "@i", "=i", "i", "!i", + "@I", "=I", "I", "!I", + "@l", "=l", "l", "!l", + "@L", "=L", "L", "!L", + "@q", "=q", "q", "!q", + "@Q", "=Q", "Q", "!Q", +] +# fmt: on + + +def xmemoryview(view: bytes | bytearray | memoryview[int], format: _Formats) -> memoryview[int] | _xmemoryview[int]: """Cast a memoryview to the specified format, including endianness. The regular ``memoryview.cast()`` method only supports host endianness. While that should be fine 99% of the time (most of the world runs on little endian systems), we'd rather it be fine 100% of the time. This utility method ensures that by transparently converting between endianness if it doesn't match the host endianness. + While this should technically work on any format supported by ``memoryview.cast()``, it only makes sense to use it + for integer formats, and thus the typing is limited to those. + If the host endianness matches the requested endianness, this simply returns a regular ``memoryview.cast()``. See ``memoryview.cast()`` for more details on what that actually does. Args: - view: The bytes object or memoryview to cast. + view: The bytes object or memoryview to cast. format: The format to cast to in ``struct`` format syntax.
Raises: @@ -33,7 +49,7 @@ def xmemoryview(view: bytes, format: str) -> memoryview | _xmemoryview: if isinstance(view, bytes | bytearray): view = memoryview(view) - if not isinstance(view, memoryview): + if not isinstance(view, memoryview): # type: ignore raise TypeError("view must be a memoryview, bytes or bytearray object") endian = format[0] @@ -68,43 +84,44 @@ def __init__(self, view: memoryview, format: str): self._struct_to = struct.Struct(format) def tolist(self) -> list[int]: - return self._convert_from_native(self._view.tolist()) + return list(self._convert_from_native(self._view.tolist())) - def _convert_from_native(self, value: list[int] | int) -> int: + def _convert_from_native(self, value: list[int] | int) -> tuple[int, ...]: if isinstance(value, list): endian = self._format[0] fmt = self._format[1] pck = f"{len(value)}{fmt}" - return list(struct.unpack(f"{endian}{pck}", struct.pack(f"={pck}", *value))) - return self._struct_to.unpack(self._struct_frm.pack(value))[0] + return struct.unpack(f"{endian}{pck}", struct.pack(f"={pck}", *value)) + return self._struct_to.unpack(self._struct_frm.pack(value)) - def _convert_to_native(self, value: list[int] | int) -> int: + def _convert_to_native(self, value: list[int] | int) -> tuple[int, ...]: if isinstance(value, list): endian = self._format[0] fmt = self._format[1] pck = f"{len(value)}{fmt}" - return list(struct.unpack(f"={pck}", struct.pack(f"{endian}{pck}", *value))) - return self._struct_frm.unpack(self._struct_to.pack(value))[0] + return struct.unpack(f"={pck}", struct.pack(f"{endian}{pck}", *value)) + return self._struct_frm.unpack(self._struct_to.pack(value)) - def __getitem__(self, idx: int | slice) -> int | bytes: - value = self._view[idx] + def __getitem__(self, idx: int | slice) -> int | _xmemoryview: if isinstance(idx, int): - return self._convert_from_native(value) + return self._convert_from_native(self._view[idx])[0] if isinstance(idx, slice): return _xmemoryview(self._view[idx], self._format) 
raise TypeError("Invalid index type") def __setitem__(self, idx: int | slice, value: list[int] | int) -> None: - if isinstance(idx, int | slice): - self._view[idx] = self._convert_to_native(value) + if isinstance(idx, int): + self._view[idx] = self._convert_to_native(value)[0] + elif isinstance(idx, slice): + self._view[idx] = list(self._convert_to_native(value)) # type: ignore else: raise TypeError("Invalid index type") def __len__(self) -> int: return len(self._view) - def __eq__(self, other: memoryview | _xmemoryview) -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, _xmemoryview): other = other._view return self._view.__eq__(other) diff --git a/pyproject.toml b/pyproject.toml index aa0e4dd..705b09a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ test = [ lint = [ "ruff==0.12.12", "vermin", + "ty", ] build = [ "build", diff --git a/tests/compression/test_lz4.py b/tests/compression/test_lz4.py index 7b6d43f..7a1ab15 100644 --- a/tests/compression/test_lz4.py +++ b/tests/compression/test_lz4.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from types import ModuleType - from pytest_benchmark.fixture import BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore @pytest.mark.parametrize( diff --git a/tests/compression/test_lzbitmap.py b/tests/compression/test_lzbitmap.py index ea4e8d4..8e75d81 100644 --- a/tests/compression/test_lzbitmap.py +++ b/tests/compression/test_lzbitmap.py @@ -8,7 +8,7 @@ from dissect.util.compression import lzbitmap if TYPE_CHECKING: - from pytest_benchmark.fixture import BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore @pytest.mark.parametrize( diff --git a/tests/compression/test_lzfse.py b/tests/compression/test_lzfse.py index cc52d1e..c92044f 100644 --- a/tests/compression/test_lzfse.py +++ b/tests/compression/test_lzfse.py @@ -8,7 +8,7 @@ from dissect.util.compression import lzfse if TYPE_CHECKING: - from pytest_benchmark.fixture import 
BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore @pytest.mark.parametrize( diff --git a/tests/compression/test_lznt1.py b/tests/compression/test_lznt1.py index 64c2278..b9dc799 100644 --- a/tests/compression/test_lznt1.py +++ b/tests/compression/test_lznt1.py @@ -8,7 +8,7 @@ from dissect.util.compression import lznt1 if TYPE_CHECKING: - from pytest_benchmark.fixture import BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore @pytest.mark.parametrize( diff --git a/tests/compression/test_lzo.py b/tests/compression/test_lzo.py index f5b4c0b..6262f6e 100644 --- a/tests/compression/test_lzo.py +++ b/tests/compression/test_lzo.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from types import ModuleType - from pytest_benchmark.fixture import BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore @pytest.mark.parametrize( diff --git a/tests/compression/test_lzvn.py b/tests/compression/test_lzvn.py index 6d8d3e5..c89f586 100644 --- a/tests/compression/test_lzvn.py +++ b/tests/compression/test_lzvn.py @@ -8,7 +8,7 @@ from dissect.util.compression import lzvn if TYPE_CHECKING: - from pytest_benchmark.fixture import BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore @pytest.mark.parametrize( diff --git a/tests/compression/test_lzxpress.py b/tests/compression/test_lzxpress.py index 09ff0a4..ec3250f 100644 --- a/tests/compression/test_lzxpress.py +++ b/tests/compression/test_lzxpress.py @@ -8,7 +8,7 @@ from dissect.util.compression import lzxpress if TYPE_CHECKING: - from pytest_benchmark.fixture import BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore @pytest.mark.parametrize( diff --git a/tests/compression/test_lzxpress_huffman.py b/tests/compression/test_lzxpress_huffman.py index f5434b2..49529d0 100644 --- a/tests/compression/test_lzxpress_huffman.py +++ b/tests/compression/test_lzxpress_huffman.py @@ -8,7 
+8,7 @@ from dissect.util.compression import lzxpress_huffman if TYPE_CHECKING: - from pytest_benchmark.fixture import BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore @pytest.mark.parametrize( diff --git a/tests/compression/test_sevenbit.py b/tests/compression/test_sevenbit.py index 6a3ed8f..9507780 100644 --- a/tests/compression/test_sevenbit.py +++ b/tests/compression/test_sevenbit.py @@ -7,7 +7,7 @@ from dissect.util.compression import sevenbit if TYPE_CHECKING: - from pytest_benchmark.fixture import BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore def test_sevenbit_compress() -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 205ae1e..6dda30a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import importlib.util from types import ModuleType +from typing import cast import pytest @@ -26,13 +27,13 @@ def pytest_addoption(parser: pytest.Parser) -> None: def _native_or_python(module: ModuleType, name: str, request: pytest.FixtureRequest) -> ModuleType: if request.param: - if not (module := getattr(module, f"{name}_native", None)): + if not (module := cast("ModuleType", getattr(module, f"{name}_native", None))): (pytest.fail if request.config.getoption("--force-native") else pytest.skip)( "_native module is unavailable" ) return module - return getattr(module, f"{name}_python", None) + return getattr(module, f"{name}_python") @pytest.fixture(scope="session", params=[True, False], ids=["native", "python"]) diff --git a/tests/test_cpio.py b/tests/test_cpio.py index 8d8d057..ca15b5a 100644 --- a/tests/test_cpio.py +++ b/tests/test_cpio.py @@ -25,13 +25,15 @@ def _verify_archive(archive: TarFile) -> None: assert small_file.name == "small-file" assert small_file.size == 9 assert small_file.isfile() - assert archive.extractfile(small_file).read() == b"contents\n" + assert (fh := archive.extractfile(small_file)) + assert fh.read() == b"contents\n" large_file 
= archive.getmember("large-file") assert large_file.name == "large-file" assert large_file.size == 0x3FC000 assert small_file.isfile() - assert archive.extractfile(large_file).read() == b"".join([bytes([i] * 4096) for i in range(255)]) * 4 + assert (fh := archive.extractfile(large_file)) + assert fh.read() == b"".join([bytes([i] * 4096) for i in range(255)]) * 4 symlink_1 = archive.getmember("symlink-1") assert symlink_1.issym() diff --git a/tests/test_hash.py b/tests/test_hash.py index 3a57ddb..d45e441 100644 --- a/tests/test_hash.py +++ b/tests/test_hash.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from types import ModuleType - from pytest_benchmark.fixture import BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore def test_crc32(crc32c: ModuleType) -> None: diff --git a/tests/test_plist.py b/tests/test_plist.py index e3cc178..a2f3842 100644 --- a/tests/test_plist.py +++ b/tests/test_plist.py @@ -3,12 +3,14 @@ import datetime import sys import uuid +from io import BytesIO from plistlib import UID +from typing import Any, cast from unittest.mock import patch import pytest -from dissect.util.plist import NSKeyedArchiver +from dissect.util.plist import NSKeyedArchiver, NSObject @pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python3.8 or higher") @@ -132,10 +134,10 @@ def test_plist_nskeyedarchiver() -> None: ], } with patch("plistlib.load", return_value=data): - obj = NSKeyedArchiver(None) + obj = NSKeyedArchiver(BytesIO(b"")) assert "root" in obj.top - root = obj["root"] + root = cast("NSObject", obj["root"]) assert root._classname == "TestObject" assert root.Null is None assert root.Integer == 1337 @@ -147,4 +149,4 @@ def test_plist_nskeyedarchiver() -> None: assert root.URL == "http://base/relative" assert root.URLBaseless == "relative" assert root.Array == root.Set == [1, "TestString"] - assert list(root.Dict.items()) == [("DictKey", "TestString")] + assert list(cast("dict[Any, Any]", root.Dict).items()) == 
[("DictKey", "TestString")] diff --git a/tests/test_sid.py b/tests/test_sid.py index bea0beb..fc13d75 100644 --- a/tests/test_sid.py +++ b/tests/test_sid.py @@ -8,10 +8,10 @@ from dissect.util import sid if TYPE_CHECKING: - from pytest_benchmark.fixture import BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore -def id_fn(val: bytes | str) -> str: +def id_fn(val: bytes | str | None) -> str: if isinstance(val, io.BytesIO): val = val.getvalue() diff --git a/tests/test_stream.py b/tests/test_stream.py index 0aae125..91e970b 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -10,7 +10,7 @@ from dissect.util import stream if TYPE_CHECKING: - from pytest_benchmark.fixture import BenchmarkFixture + from pytest_benchmark.fixture import BenchmarkFixture # type: ignore def test_range_stream() -> None: diff --git a/tox.ini b/tox.ini index ad73841..7f8c2ff 100644 --- a/tox.ini +++ b/tox.ini @@ -79,6 +79,7 @@ commands = package = skip dependency_groups = lint commands = + ty check dissect tests ruff check dissect tests ruff format --check dissect tests vermin -t=3.10- --no-tips --lint dissect tests