diff --git a/dissect/util/stream.py b/dissect/util/stream.py index db163bc..63c4ad0 100644 --- a/dissect/util/stream.py +++ b/dissect/util/stream.py @@ -2,11 +2,13 @@ import io import os +import struct import sys import zlib from bisect import bisect_left, bisect_right +from functools import lru_cache from threading import Lock -from typing import BinaryIO +from typing import BinaryIO, Callable STREAM_BUFFER_SIZE = int(os.getenv("DISSECT_STREAM_BUFFER_SIZE", io.DEFAULT_BUFFER_SIZE)) @@ -644,3 +646,78 @@ def readall(self) -> bytes: chunks.append(data) return b"".join(chunks) + + +class CompressedStream(AlignedStream): + """Create a stream from a file-like object which is compresssed. Decompressing it with the supplied decompressor. + + This is useful for compressed file formats that store the compressed data in chunks like: + * Windows Imaging (WIM) archives (LZXPRESS + Huffman, LZX) + * Windows Overlay Filter (WOF) compressed files (LZXPRESS + Huffman, LZNT1, LZX) . + * Hiberfil.sys files (LZXPRESS). + + Args: + fh: The source file-like object. + offset: The offset in the source file-like object to start decompressing from. + compressed_size: The size of the compressed data. + original_size: The size of the decompressed data. + decompressor: A function that decompresses a chunk of data. + chunk_size: The size of each chunk in bytes. Default is 32 KiB. + chunks: A tuple of offsets to each chunk of compressed data. + """ + + def __init__( + self, + fh: BinaryIO, + offset: int, + compressed_size: int, + original_size: int, + decompressor: Callable[[bytes], bytes], + chunk_size: int = 32 * 1024, + chunks: tuple[int, ...] | None = None, + ): + self.fh = fh + self.offset = offset + self.compressed_size = compressed_size + self.original_size = original_size + self.decompressor = decompressor + self.chunk_size = chunk_size + self.chunks = chunks or (offset,) + + self._read_chunk = lru_cache(32)(self._read_chunk) + super().__init__(self.original_size) + + def _read(self, offset: int, length: int) -> bytes: + result = [] + + num_chunks = len(self.chunks) + chunk, offset_in_chunk = divmod(offset, self.chunk_size) + + while length: + if chunk >= num_chunks: + # We somehow requested more data than we have runs for + break + + chunk_offset = self.chunks[chunk] + if chunk < num_chunks - 1: + next_chunk_offset = self.chunks[chunk + 1] + chunk_remaining = self.chunk_size - offset_in_chunk + else: + next_chunk_offset = self.compressed_size + chunk_remaining = (self.original_size - (chunk * self.chunk_size)) - offset_in_chunk + + read_length = min(chunk_remaining, length) + + buf = self._read_chunk(chunk_offset, next_chunk_offset - chunk_offset) + result.append(buf[offset_in_chunk : offset_in_chunk + read_length]) + + length -= read_length + offset += read_length + chunk += 1 + + return b"".join(result) + + def _read_chunk(self, offset: int, size: int) -> bytes: + self.fh.seek(offset) + buf = self.fh.read(size) + return self.decompressor(buf) diff --git a/tests/test_stream.py b/tests/test_stream.py index bd18cbe..83eae13 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -5,6 +5,7 @@ import pytest from dissect.util import stream +from dissect.util.compression import lzxpress_huffman def test_range_stream() -> None: @@ -217,3 +218,37 @@ def test_zlib_stream() -> None: fh.seek(0) assert fh.read() == data + + +def test_compressed_stream_lzxpress_huffman() -> None: + data = io.BytesIO( + bytes.fromhex( + "4034030000000000000000000000000000000000000000000000000000000000" + "0000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000000000000000000000000000000000000000000000000000" + "0300000000000010000000000000000000000000000000000000000000000000" + "0000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000000000000000000000000000000000000000000000000000" + "0000000000000000000000000000000000000000000000000000000000000000" + "a2e700b0fffc0ffffc07fffc030000fffc03" + ) + ) + + size = 8192 + compressed_size = 274 + + fh = stream.CompressedStream(data, 0, compressed_size, size, lzxpress_huffman.decompress) + assert fh.chunks == (0,) + + fh.seek(0) + assert fh.read(4096) == b"\x01" * 4096 + + fh.seek(2048) + assert fh.read(4096) == b"\x01" * 2048 + b"\x02" * 2048 + + fh.seek(6144) + assert fh.read(2048) == b"\x03" * 1024 + b"\x04" * 1024 + + fh.seek(0) + assert fh.read() == b"\x01" * 4096 + b"\x02" * 2048 + b"\x03" * 1024 + b"\x04" * 1024