diff --git a/ofrak_core/ofrak/core/binary.py b/ofrak_core/ofrak/core/binary.py index aa1f107d9..55f446324 100644 --- a/ofrak_core/ofrak/core/binary.py +++ b/ofrak_core/ofrak/core/binary.py @@ -41,9 +41,8 @@ class BinaryExtendModifier(Modifier[BinaryExtendConfig]): async def modify(self, resource: Resource, config: BinaryExtendConfig): if len(config.content) == 0: raise ValueError("Content of the extended space not provided") - data = await resource.get_data() - data += config.content - resource.queue_patch(Range(0, await resource.get_data_length()), data) + orig_data_length = await resource.get_data_length() + resource.queue_patch(Range(orig_data_length, orig_data_length), config.content) @dataclass diff --git a/ofrak_core/ofrak/core/bzip2.py b/ofrak_core/ofrak/core/bzip2.py index 03ea6487f..1ef53307c 100644 --- a/ofrak_core/ofrak/core/bzip2.py +++ b/ofrak_core/ofrak/core/bzip2.py @@ -36,8 +36,8 @@ async def unpack(self, resource: Resource, config=None): :param resource: :param config: """ - resource_data = await resource.get_data() - decompressed_data = bz2.decompress(resource_data) + with await resource.get_data_memoryview() as resource_data: + decompressed_data = bz2.decompress(resource_data) await resource.create_child( tags=(GenericBinary,), data=decompressed_data, @@ -59,7 +59,8 @@ async def pack(self, resource: Resource, config=None): :param config: """ bzip2_child = await resource.get_only_child() - bzip2_compressed = bz2.compress(await bzip2_child.get_data()) + with await bzip2_child.get_data_memoryview() as buffer: + bzip2_compressed = bz2.compress(buffer) original_size = await resource.get_data_length() resource.queue_patch(Range(0, original_size), bzip2_compressed) diff --git a/ofrak_core/ofrak/core/checksum.py b/ofrak_core/ofrak/core/checksum.py index 895c09693..2bec3e097 100644 --- a/ofrak_core/ofrak/core/checksum.py +++ b/ofrak_core/ofrak/core/checksum.py @@ -23,9 +23,9 @@ class Sha256Analyzer(Analyzer[None, Sha256Attributes]): outputs = (Sha256Attributes,) async def analyze(self, resource: Resource, config=None) -> Sha256Attributes: - data = await resource.get_data() sha256 = hashlib.sha256() - sha256.update(data) + with await resource.get_data_memoryview() as data: + sha256.update(data) return Sha256Attributes(sha256.hexdigest()) @@ -43,7 +43,7 @@ class Md5Analyzer(Analyzer[None, Md5Attributes]): outputs = (Md5Attributes,) async def analyze(self, resource: Resource, config=None) -> Md5Attributes: - data = await resource.get_data() md5 = hashlib.md5() - md5.update(data) + with await resource.get_data_memoryview() as data: + md5.update(data) return Md5Attributes(md5.hexdigest()) diff --git a/ofrak_core/ofrak/core/comments.py b/ofrak_core/ofrak/core/comments.py index 1d9139571..afd1096a5 100644 --- a/ofrak_core/ofrak/core/comments.py +++ b/ofrak_core/ofrak/core/comments.py @@ -34,7 +34,7 @@ async def modify(self, resource: Resource, config: AddCommentModifierConfig) -> # Verify that the given range is valid for the given resource. 
config_range = config.comment[0] if config_range is not None: - if config_range.start < 0 or config_range.end > len(await resource.get_data()): + if config_range.start < 0 or config_range.end > await resource.get_data_length(): raise ValueError( f"Range {config_range} is outside the bounds of " f"resource {resource.get_id().hex()}" diff --git a/ofrak_core/ofrak/core/cpio.py b/ofrak_core/ofrak/core/cpio.py index 685701990..ed58dad5c 100644 --- a/ofrak_core/ofrak/core/cpio.py +++ b/ofrak_core/ofrak/core/cpio.py @@ -86,7 +86,6 @@ class CpioUnpacker(Unpacker[None]): async def unpack(self, resource: Resource, config=None): cpio_v = await resource.view_as(CpioFilesystem) - resource_data = await cpio_v.resource.get_data() with tempfile.TemporaryDirectory() as temp_flush_dir: cmd = [ "cpio", @@ -99,7 +98,8 @@ async def unpack(self, resource: Resource, config=None): stderr=asyncio.subprocess.PIPE, cwd=temp_flush_dir, ) - await proc.communicate(input=resource_data) + with await resource.get_data_memoryview() as resource_data: + await proc.communicate(input=resource_data) # if proc.returncode: # raise CalledProcessError(returncode=proc.returncode, cmd=cmd) await cpio_v.initialize_from_disk(temp_flush_dir) diff --git a/ofrak_core/ofrak/core/data.py b/ofrak_core/ofrak/core/data.py index a1d4c4603..c0ad1ac28 100644 --- a/ofrak_core/ofrak/core/data.py +++ b/ofrak_core/ofrak/core/data.py @@ -23,12 +23,12 @@ class DataWord(MemoryRegion): xrefs_to: Tuple[int, ...] async def get_value_unsigned(self) -> int: - data = await self.resource.get_data() - return struct.unpack(self.format_string.upper(), data)[0] + with await self.resource.get_data_memoryview() as data: + return struct.unpack(self.format_string.upper(), data)[0] async def get_value_signed(self) -> int: - data = await self.resource.get_data() - return struct.unpack(self.format_string.lower(), data)[0] + with await self.resource.get_data_memoryview() as data: + return struct.unpack(self.format_string.lower(), data)[0] @dataclass(**ResourceAttributes.DATACLASS_PARAMS) diff --git a/ofrak_core/ofrak/core/dtb.py b/ofrak_core/ofrak/core/dtb.py index aa02549e3..c3a3358b7 100644 --- a/ofrak_core/ofrak/core/dtb.py +++ b/ofrak_core/ofrak/core/dtb.py @@ -102,30 +102,30 @@ class DtbHeaderAnalyzer(Analyzer[None, DtbHeader]): outputs = (DtbHeader,) async def analyze(self, resource: Resource, config: None) -> DtbHeader: - header_data = await resource.get_data() - ( - dtb_magic, - totalsize, - off_dt_struct, - off_dt_strings, - off_mem_rsvmap, - version, - last_comp_version, - ) = struct.unpack(">IIIIIII", header_data[:28]) - assert dtb_magic == DTB_MAGIC_SIGNATURE, ( - f"DTB Magic bytes not matching." - f"Expected: {DTB_MAGIC_SIGNATURE} " - f"Unpacked: {dtb_magic}" - ) - boot_cpuid_phys = 0 - dtb_strings_size = 0 - dtb_struct_size = 0 - if version >= 2: - boot_cpuid_phys = struct.unpack(">I", header_data[28:32])[0] - if version >= 3: - dtb_strings_size = struct.unpack(">I", header_data[32:36])[0] - if version >= 17: - dtb_struct_size = struct.unpack(">I", header_data[36:40])[0] + with await resource.get_data_memoryview(Range(0, 40)) as header_data: + ( + dtb_magic, + totalsize, + off_dt_struct, + off_dt_strings, + off_mem_rsvmap, + version, + last_comp_version, + ) = struct.unpack(">IIIIIII", header_data[:28]) + assert dtb_magic == DTB_MAGIC_SIGNATURE, ( + f"DTB Magic bytes not matching." 
+ f"Expected: {DTB_MAGIC_SIGNATURE} " + f"Unpacked: {dtb_magic}" + ) + boot_cpuid_phys = 0 + dtb_strings_size = 0 + dtb_struct_size = 0 + if version >= 2: + boot_cpuid_phys = struct.unpack(">I", header_data[28:32])[0] + if version >= 3: + dtb_strings_size = struct.unpack(">I", header_data[32:36])[0] + if version >= 17: + dtb_struct_size = struct.unpack(">I", header_data[36:40])[0] return DtbHeader( dtb_magic, diff --git a/ofrak_core/ofrak/core/elf/analyzer.py b/ofrak_core/ofrak/core/elf/analyzer.py index 2373316c3..b39af9627 100644 --- a/ofrak_core/ofrak/core/elf/analyzer.py +++ b/ofrak_core/ofrak/core/elf/analyzer.py @@ -50,8 +50,8 @@ class ElfBasicHeaderAttributesAnalyzer(Analyzer[None, ElfBasicHeader]): outputs = (ElfBasicHeader,) async def analyze(self, resource: Resource, config=None) -> ElfBasicHeader: - tmp = await resource.get_data() - deserializer = BinaryDeserializer(io.BytesIO(tmp)) + with await resource.get_data_memoryview() as tmp: + deserializer = BinaryDeserializer(io.BytesIO(tmp)) ( ei_magic, ei_class, diff --git a/ofrak_core/ofrak/core/entropy/entropy.py b/ofrak_core/ofrak/core/entropy/entropy.py index 803ac93d8..301bcad28 100644 --- a/ofrak_core/ofrak/core/entropy/entropy.py +++ b/ofrak_core/ofrak/core/entropy/entropy.py @@ -76,7 +76,7 @@ async def analyze(self, resource: Resource, config=None, depth=0) -> DataSummary def sample_entropy( - data: bytes, resource_id: bytes, window_size=256, max_samples=2**20 + data: bytearray, resource_id: bytes, window_size=256, max_samples=2**20 ) -> bytes: # pragma: no cover """ Return a list of entropy values where each value represents the Shannon entropy of the byte diff --git a/ofrak_core/ofrak/core/entropy/entropy_c.py b/ofrak_core/ofrak/core/entropy/entropy_c.py index 1987af0e9..85fb6306b 100644 --- a/ofrak_core/ofrak/core/entropy/entropy_c.py +++ b/ofrak_core/ofrak/core/entropy/entropy_c.py @@ -33,7 +33,8 @@ def entropy_c( if len(data) <= window_size: return b"" entropy = ctypes.create_string_buffer(len(data) - window_size) - errval = C_ENTROPY_FUNC(data, len(data), entropy, window_size, C_LOG_TYPE(log_percent)) + buffer = (ctypes.c_char * len(data)).from_buffer_copy(data) + errval = C_ENTROPY_FUNC(buffer, len(data), entropy, window_size, C_LOG_TYPE(log_percent)) if errval != 0: raise ValueError("Bad input to entropy function.") - return bytes(entropy.raw) + return entropy.raw diff --git a/ofrak_core/ofrak/core/entropy/entropy_py.py b/ofrak_core/ofrak/core/entropy/entropy_py.py index a1c68b60f..938cdd928 100644 --- a/ofrak_core/ofrak/core/entropy/entropy_py.py +++ b/ofrak_core/ofrak/core/entropy/entropy_py.py @@ -25,7 +25,7 @@ def entropy_py( histogram[b] += 1 # Calculate the entropy using a sliding window - entropy = [0] * (len(data) - window_size) + entropy = bytearray(max(0, len(data) - window_size)) last_percent_logged = 0 for i in range(len(entropy)): entropy[i] = math.floor(255 * _shannon_entropy(histogram, window_size)) @@ -35,7 +35,7 @@ def entropy_py( if percent > last_percent_logged and percent % 10 == 0: log_percent(percent) last_percent_logged = percent - return bytes(entropy) + return entropy def _shannon_entropy(distribution: List[int], window_size: int) -> float: diff --git a/ofrak_core/ofrak/core/filesystem.py b/ofrak_core/ofrak/core/filesystem.py index d62c8836f..e73fc15d0 100644 --- a/ofrak_core/ofrak/core/filesystem.py +++ b/ofrak_core/ofrak/core/filesystem.py @@ -214,7 +214,7 @@ async def flush_to_disk(self, root_path: str = ".", filename: Optional[str] = No elif self.is_file(): file_name = 
os.path.join(root_path, entry_path) with open(file_name, "wb") as f: - f.write(await self.resource.get_data()) + await self.resource.write_to(f, pack=False) self.apply_stat_attrs(file_name) elif self.is_device(): device_name = os.path.join(root_path, entry_path) diff --git a/ofrak_core/ofrak/core/flash.py b/ofrak_core/ofrak/core/flash.py index b89e7fd8f..5a50dc2b4 100644 --- a/ofrak_core/ofrak/core/flash.py +++ b/ofrak_core/ofrak/core/flash.py @@ -426,80 +426,83 @@ async def unpack(self, resource: Resource, config=None): oob_resource = resource # Parent FlashEccResource is created, redefine data to limited scope - data = await oob_resource.get_data() - data_len = len(data) - - # Now add children blocks until we reach the tail block - offset = 0 - only_data = list() - only_ecc = list() - for block in flash_attr.iterate_through_all_blocks(data_len, True): - block_size = flash_attr.get_block_size(block) - block_end_offset = offset + block_size - if block_end_offset > data_len: - LOGGER.info( - f"Block offset {block_end_offset} is {block_end_offset - data_len} larger " - f"than {data_len}. In this case unpacking is best effort and end of unpacked " - f"child might not be accurate." - ) - break - block_range = Range(offset, block_end_offset) - block_data = await oob_resource.get_data(range=block_range) - - # Iterate through every field in block, dealing with ECC and DATA - block_ecc_range = None - block_data_range = None - field_offset = 0 - for field_index, field in enumerate(block): - field_range = Range(field_offset, field_offset + field.size) - - # We must check all blocks anyway so deal with ECC here - if field.field_type == FlashFieldType.ECC: - block_ecc_range = field_range - cur_block_ecc = block_data[block_ecc_range.start : block_ecc_range.end] - only_ecc.append(cur_block_ecc) - # Add hash of everything up to the ECC to our dict for faster packing - block_data_hash = md5(block_data[: block_ecc_range.start]).digest() - DATA_HASHES[block_data_hash] = cur_block_ecc - - if field.field_type == FlashFieldType.DATA: - block_data_range = field_range - # Get next ECC range - future_offset = field_offset - block_list = list(block) - for future_field in block_list[field_index:]: - if future_field.field_type == FlashFieldType.ECC: - block_ecc_range = Range( - future_offset, future_offset + future_field.size - ) - future_offset += future_field.size - - if block_ecc_range is not None: - # Try decoding/correcting with ECC, report any error - try: - # Assumes that data comes before ECC - if (ecc_attr is not None) and (ecc_attr.ecc_class is not None): - only_data.append( - ecc_attr.ecc_class.decode(block_data[: block_ecc_range.end])[ - block_data_range.start : block_data_range.end - ] - ) - else: - raise UnpackerError( - "Tried to correct with ECC without providing an ecc_class in FlashEccAttributes" + with await oob_resource.get_data_memoryview() as data: + data_len = len(data) + + # Now add children blocks until we reach the tail block + offset = 0 + only_data = bytearray() + only_ecc = bytearray() + for block in flash_attr.iterate_through_all_blocks(data_len, True): + block_size = flash_attr.get_block_size(block) + block_end_offset = offset + block_size + if block_end_offset > data_len: + LOGGER.info( + f"Block offset {block_end_offset} is {block_end_offset - data_len} larger " + f"than {data_len}. In this case unpacking is best effort and end of unpacked " + f"child might not be accurate." 
+ ) + break + with data[offset:block_end_offset] as block_memview: + block_data = bytes(block_memview) + + # Iterate through every field in block, dealing with ECC and DATA + block_ecc_range = None + block_data_range = None + field_offset = 0 + for field_index, field in enumerate(block): + field_range = Range(field_offset, field_offset + field.size) + + # We must check all blocks anyway so deal with ECC here + if field.field_type == FlashFieldType.ECC: + block_ecc_range = field_range + cur_block_ecc = block_data[block_ecc_range.start : block_ecc_range.end] + only_ecc.extend(cur_block_ecc) + # Add hash of everything up to the ECC to our dict for faster packing + block_data_hash = md5(block_data[: block_ecc_range.start]).digest() + DATA_HASHES[block_data_hash] = cur_block_ecc + + if field.field_type == FlashFieldType.DATA: + block_data_range = field_range + # Get next ECC range + future_offset = field_offset + block_list = list(block) + for future_field in block_list[field_index:]: + if future_field.field_type == FlashFieldType.ECC: + block_ecc_range = Range( + future_offset, future_offset + future_field.size ) - except EccError: - raise UnpackerError("ECC correction failed") - else: - # No ECC found in block, just add the data directly - only_data.append(block_data[block_data_range.start : block_data_range.end]) - field_offset += field.size - offset += block_size - + future_offset += future_field.size + + if block_ecc_range is not None: + # Try decoding/correcting with ECC, report any error + try: + # Assumes that data comes before ECC + if (ecc_attr is not None) and (ecc_attr.ecc_class is not None): + only_data.extend( + ecc_attr.ecc_class.decode( + block_data[: block_ecc_range.end] + )[block_data_range.start : block_data_range.end] + ) + else: + raise UnpackerError( + "Tried to correct with ECC without providing an ecc_class in FlashEccAttributes" + ) + except EccError: + raise UnpackerError("ECC correction failed") + else: + # No ECC found in block, just add the data directly + only_data.extend( + block_data[block_data_range.start : block_data_range.end] + ) + field_offset += field.size + offset += block_size + if not only_data: + only_data = bytearray(data) # Add all block data to logical resource for recursive unpacking await oob_resource.create_child( tags=(FlashLogicalDataResource,), - data=b"".join(only_data) if only_data else data, + data=only_data, attributes=[ flash_attr, ], @@ -507,7 +510,7 @@ async def unpack(self, resource: Resource, config=None): if ecc_attr is not None: await oob_resource.create_child( tags=(FlashLogicalEccResource,), - data=b"".join(only_ecc), + data=only_ecc, attributes=[ ecc_attr, ], diff --git a/ofrak_core/ofrak/core/gzip.py b/ofrak_core/ofrak/core/gzip.py index 66df3d5ee..db6811bf3 100644 --- a/ofrak_core/ofrak/core/gzip.py +++ b/ofrak_core/ofrak/core/gzip.py @@ -52,12 +52,12 @@ class GzipUnpacker(Unpacker[None]): external_dependencies = (PIGZ,) async def unpack(self, resource: Resource, config=None): - data = await resource.get_data() - unpacked_data = await self.unpack_with_zlib_module(data) + with await resource.get_data_memoryview() as data: + unpacked_data = self.unpack_with_zlib_module(data) return await resource.create_child(tags=(GenericBinary,), data=unpacked_data) @staticmethod - async def unpack_with_zlib_module(data: bytes) -> bytes: + def unpack_with_zlib_module(data: bytes) -> bytes: # We use zlib.decompressobj instead of the gzip module to decompress # because of a bug that causes gzip to raise BadGzipFile if there's # trailing garbage 
after a compressed file instead of correctly ignoring it @@ -67,7 +67,7 @@ async def unpack_with_zlib_module(data: bytes) -> bytes: # a loop and concatenate them in the end. \037\213 are magic bytes # indicating the start of a gzip header. chunks = [] - while data.startswith(b"\037\213"): + while data[:2] == b"\037\213": # wbits > 16 handles the gzip header and footer decompressor = zlib.decompressobj(wbits=16 + zlib.MAX_WBITS) chunks.append(decompressor.decompress(data)) @@ -91,25 +91,24 @@ class GzipPacker(Packer[None]): async def pack(self, resource: Resource, config=None): gzip_view = await resource.view_as(GzipData) gzip_child_r = await gzip_view.get_file() - data = await gzip_child_r.get_data() - - if len(data) >= 1024 * 1024 and await PIGZInstalled.is_pigz_installed(): - packed_data = await self.pack_with_pigz(data) - else: - packed_data = await self.pack_with_zlib_module(data) + with await gzip_child_r.get_data_memoryview() as data: + if len(data) >= 1024 * 1024 and await PIGZInstalled.is_pigz_installed(): + packed_data = await self.pack_with_pigz(data) + else: + packed_data = self.pack_with_zlib_module(data) original_gzip_size = await gzip_view.resource.get_data_length() resource.queue_patch(Range(0, original_gzip_size), data=packed_data) @staticmethod - async def pack_with_zlib_module(data: bytes) -> bytes: + def pack_with_zlib_module(data: memoryview) -> bytes: compressor = zlib.compressobj(wbits=16 + zlib.MAX_WBITS) result = compressor.compress(data) result += compressor.flush() return result @staticmethod - async def pack_with_pigz(data: bytes) -> bytes: + async def pack_with_pigz(data: memoryview) -> bytes: with tempfile.NamedTemporaryFile(delete_on_close=False) as uncompressed_file: uncompressed_file.write(data) uncompressed_file.close() diff --git a/ofrak_core/ofrak/core/iso9660.py b/ofrak_core/ofrak/core/iso9660.py index 4549846b3..d7844f822 100644 --- a/ofrak_core/ofrak/core/iso9660.py +++ b/ofrak_core/ofrak/core/iso9660.py @@ -113,7 +113,8 @@ async def analyze(self, resource: Resource, config=None): udf_version = None iso = PyCdlib() - iso.open_fp(BytesIO(await resource.get_data())) + with await resource.get_data_memoryview() as iso_data: + iso.open_fp(BytesIO(iso_data)) interchange_level = iso.interchange_level has_joliet = iso.has_joliet() @@ -166,14 +167,13 @@ class ISO9660Unpacker(Unpacker[None]): children = (ISO9660Entry,) async def unpack(self, resource: Resource, config=None): - iso_data = await resource.get_data() - iso_attributes = await resource.analyze(ISO9660ImageAttributes) resource.add_attributes(iso_attributes) iso_resource = await resource.view_as(ISO9660Image) iso = PyCdlib() - iso.open_fp(BytesIO(iso_data)) + with await resource.get_data_memoryview() as iso_data: + iso.open_fp(BytesIO(iso_data)) if iso_attributes.has_joliet: facade = iso.get_joliet_facade() diff --git a/ofrak_core/ofrak/core/lzma.py b/ofrak_core/ofrak/core/lzma.py index f9b7db4ba..2f339b9de 100644 --- a/ofrak_core/ofrak/core/lzma.py +++ b/ofrak_core/ofrak/core/lzma.py @@ -1,6 +1,5 @@ import logging import lzma -from io import BytesIO from typing import Union from ofrak.component.packer import Packer @@ -41,8 +40,6 @@ class LzmaUnpacker(Unpacker[None]): children = (GenericBinary,) async def unpack(self, resource: Resource, config=None): - file_data = BytesIO(await resource.get_data()) - format = lzma.FORMAT_AUTO if resource.has_tag(XzData): @@ -51,13 +48,13 @@ async def unpack(self, resource: Resource, config=None): format = lzma.FORMAT_ALONE lzma_entry_data = None - compressed_data = 
file_data.read() - try: - lzma_entry_data = lzma.decompress(compressed_data, format) - except lzma.LZMAError: - LOGGER.info("Initial LZMA decompression failed. Trying with null bytes stripped") - lzma_entry_data = lzma.decompress(compressed_data.rstrip(b"\x00"), format) + with await resource.get_data_memoryview() as compressed_data: + try: + lzma_entry_data = lzma.decompress(compressed_data, format) + except lzma.LZMAError: + LOGGER.info("Initial LZMA decompression failed. Trying with null bytes stripped") + lzma_entry_data = lzma.decompress(compressed_data.tobytes().rstrip(b"\x00"), format) if lzma_entry_data is not None: await resource.create_child( @@ -80,7 +77,8 @@ async def pack(self, resource: Resource, config=None): lzma_file: Union[XzData, LzmaData] = await resource.view_as(tag) lzma_child = await lzma_file.get_child() - lzma_compressed = lzma.compress(await lzma_child.resource.get_data(), lzma_format) + with await lzma_child.resource.get_data_memoryview() as data: + lzma_compressed = lzma.compress(data, lzma_format) original_size = await lzma_file.resource.get_data_length() resource.queue_patch(Range(0, original_size), lzma_compressed) diff --git a/ofrak_core/ofrak/core/lzo.py b/ofrak_core/ofrak/core/lzo.py index ac22da72c..250b55e45 100644 --- a/ofrak_core/ofrak/core/lzo.py +++ b/ofrak_core/ofrak/core/lzo.py @@ -43,7 +43,8 @@ async def unpack(self, resource: Resource, config: ComponentConfig = None) -> No stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) - stdout, stderr = await proc.communicate(await resource.get_data()) + with await resource.get_data_memoryview() as data: + stdout, stderr = await proc.communicate(data) if proc.returncode: raise CalledProcessError(returncode=proc.returncode, cmd=cmd) @@ -60,8 +61,6 @@ class LzoPacker(Packer[None]): async def pack(self, resource: Resource, config: ComponentConfig = None): lzo_view = await resource.view_as(LzoData) - child_file = await lzo_view.get_child() - uncompressed_data = await child_file.resource.get_data() cmd = ["lzop", "-f"] proc = await asyncio.create_subprocess_exec( @@ -70,7 +69,9 @@ async def pack(self, resource: Resource, config: ComponentConfig = None): stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) - stdout, stderr = await proc.communicate(uncompressed_data) + child_file = await lzo_view.get_child() + with await child_file.resource.get_data_memoryview() as data: + stdout, stderr = await proc.communicate(data) if proc.returncode: raise CalledProcessError(returncode=proc.returncode, cmd=cmd) diff --git a/ofrak_core/ofrak/core/magic.py b/ofrak_core/ofrak/core/magic.py index e600ed01d..ab5e092b1 100644 --- a/ofrak_core/ofrak/core/magic.py +++ b/ofrak_core/ofrak/core/magic.py @@ -1,6 +1,7 @@ import logging from dataclasses import dataclass from typing import Callable, Dict, Iterable, Union +from ctypes import c_char from ofrak.component.abstract import ComponentMissingDependencyError @@ -64,13 +65,20 @@ class MagicAnalyzer(Analyzer[None, Magic]): external_dependencies = (LIBMAGIC_DEP,) async def analyze(self, resource: Resource, config=None) -> Magic: - data = await resource.get_data() if not MAGIC_INSTALLED: raise ComponentMissingDependencyError(self, LIBMAGIC_DEP) else: - magic_mime = magic.from_buffer(data, mime=True) - magic_description = magic.from_buffer(data) - return Magic(magic_mime, magic_description) + with await resource.get_data_memoryview(force_readonly=False) as buffer: + c_char_array = c_char * len(buffer) + if buffer.readonly: + buffer_ctypes = 
c_char_array.from_buffer_copy(buffer) + else: + buffer_ctypes = c_char_array.from_buffer(buffer) + + # python-magic will accept the ctypes Arrays even though its type hints don't + magic_mime = magic.from_buffer(buffer_ctypes, mime=True) # type: ignore + magic_description = magic.from_buffer(buffer_ctypes) # type: ignore + return Magic(magic_mime, magic_description) class MagicMimeIdentifier(Identifier[None]): diff --git a/ofrak_core/ofrak/core/openwrt.py b/ofrak_core/ofrak/core/openwrt.py index c044ba06d..0e9ee655c 100644 --- a/ofrak_core/ofrak/core/openwrt.py +++ b/ofrak_core/ofrak/core/openwrt.py @@ -158,42 +158,42 @@ class OpenWrtTrxUnpacker(Unpacker[None]): ) async def unpack(self, resource: Resource, config=None): - data = await resource.get_data() - # Peek into TRX version to know how big the header is - trx_version = OpenWrtTrxVersion(struct.unpack(" OpenWrtTrxHeader: - tmp = await resource.get_data() - deserializer = BinaryDeserializer( - io.BytesIO(tmp), - endianness=Endianness.LITTLE_ENDIAN, - word_size=4, - ) + with await resource.get_data_memoryview() as tmp: + deserializer = BinaryDeserializer( + io.BytesIO(tmp), + endianness=Endianness.LITTLE_ENDIAN, + word_size=4, + ) deserialized = deserializer.unpack_multiple("IIIHH") ( trx_magic, @@ -344,8 +344,10 @@ async def pack(self, resource: Resource, config=None): ], key=lambda x: x[0].start, ) - repacked_data_l = [await child.get_data() for _, child in children_by_offset] - repacked_data_b = b"".join(repacked_data_l) + repacked_data_b = bytearray() + for _, child in children_by_offset: + with await child.get_data_memoryview() as child_data: + repacked_data_b.extend(child_data) trx_length = header.get_header_length() + len(repacked_data_b) offsets = [r.start for r, _ in children_by_offset] diff --git a/ofrak_core/ofrak/core/seven_zip.py b/ofrak_core/ofrak/core/seven_zip.py index edd2d071a..0b5883ad0 100644 --- a/ofrak_core/ofrak/core/seven_zip.py +++ b/ofrak_core/ofrak/core/seven_zip.py @@ -38,7 +38,6 @@ class SevenZUnpacker(Unpacker[None]): async def unpack(self, resource: Resource, config=None): seven_zip_v = await resource.view_as(SevenZFilesystem) - resource_data = await seven_zip_v.resource.get_data() async with resource.temp_to_disk(suffix=".7z") as temp_path: with tempfile.TemporaryDirectory() as temp_flush_dir: cmd = [ diff --git a/ofrak_core/ofrak/core/ubi.py b/ofrak_core/ofrak/core/ubi.py index f59dc72b7..bb06a8410 100644 --- a/ofrak_core/ofrak/core/ubi.py +++ b/ofrak_core/ofrak/core/ubi.py @@ -188,9 +188,7 @@ async def unpack(self, resource: Resource, config=None): with tempfile.TemporaryDirectory() as temp_flush_dir: # flush to disk with open(f"{temp_flush_dir}/input.img", "wb") as temp_file: - resource_data = await resource.get_data() - temp_file.write(resource_data) - temp_file.flush() + await resource.write_to(temp_file, pack=False) # extract temp_file to temp_flush_dir cmd = [ diff --git a/ofrak_core/ofrak/core/ubifs.py b/ofrak_core/ofrak/core/ubifs.py index 250055125..ce21d664d 100644 --- a/ofrak_core/ofrak/core/ubifs.py +++ b/ofrak_core/ofrak/core/ubifs.py @@ -134,9 +134,7 @@ async def unpack(self, resource: Resource, config=None): with tempfile.TemporaryDirectory() as temp_flush_dir: # flush to disk with open(f"{temp_flush_dir}/input.img", "wb") as temp_file: - resource_data = await resource.get_data() - temp_file.write(resource_data) - temp_file.flush() + await resource.write_to(temp_file, pack=False) cmd = [ "ubireader_extract_files", diff --git a/ofrak_core/ofrak/core/uf2.py b/ofrak_core/ofrak/core/uf2.py 
index b9df86a30..94760f77d 100644 --- a/ofrak_core/ofrak/core/uf2.py +++ b/ofrak_core/ofrak/core/uf2.py @@ -99,70 +99,81 @@ async def unpack(self, resource: Resource, config=None): previous_block_no = -1 family_id = None file_num_blocks = None - block_no = 0 - - for i in range(0, data_length, 512): - data = await resource.get_data(Range(i, (i + 512))) - ( - magic_start_one, - magic_start_two, - flags, - target_addr, - payload_size, - block_no, - num_blocks, - filesize_familyID, - payload_data, - magic_end, - ) = struct.unpack("8I476sI", data) - - # basic sanity checks - if magic_start_one != UF2_MAGIC_START_ONE: - raise ValueError("Bad Start Magic") - if magic_start_two != UF2_MAGIC_START_TWO: - raise ValueError("Bad Start Magic") - if magic_end != UF2_MAGIC_END: - raise ValueError("Bad End Magic") - - if (previous_block_no - block_no) != -1: - raise ValueError("Skipped a block number") - previous_block_no = block_no - - if not file_num_blocks: - file_num_blocks = num_blocks - - if family_id is None: - family_id = filesize_familyID - else: - if family_id != filesize_familyID: - raise NotImplementedError("Multiple family IDs in file not supported") - - # unpack data - if flags & Uf2Flags.NOT_MAIN_FLASH: - # data not written to main flash - raise NotImplementedError( - "Data not written to main flash is currently not supported" - ) - elif flags & Uf2Flags.FILE_CONTAINER: - # file container - raise NotImplementedError("File containers are currently not implemented") - elif flags & Uf2Flags.FAMILY_ID_PRESENT: - data = payload_data[0:payload_size] - if len(ranges) == 0: - ranges.append((Range(target_addr, target_addr + payload_size), data)) + block_no: int = 0 + + with await resource.get_data_memoryview(Range(0, data_length)) as all_data: + for i in range(0, data_length, 512): + data: bytes + with all_data[i : i + 512] as data: + magic_start_one: int + magic_start_two: int + flags: int + target_addr: int + payload_size: int + num_blocks: int + filesize_familyID: int + payload_data: bytes + magic_end: int + ( + magic_start_one, + magic_start_two, + flags, + target_addr, + payload_size, + block_no, + num_blocks, + filesize_familyID, + payload_data, + magic_end, + ) = struct.unpack("8I476sI", data) + + # basic sanity checks + if magic_start_one != UF2_MAGIC_START_ONE: + raise ValueError("Bad Start Magic") + if magic_start_two != UF2_MAGIC_START_TWO: + raise ValueError("Bad Start Magic") + if magic_end != UF2_MAGIC_END: + raise ValueError("Bad End Magic") + + if (previous_block_no - block_no) != -1: + raise ValueError("Skipped a block number") + previous_block_no = block_no + + if not file_num_blocks: + file_num_blocks = num_blocks + + if family_id is None: + family_id = filesize_familyID else: - last_region_range, last_region_data = ranges[-1] - - # if range is adjacent, extend, otherwise start a new one - if target_addr - last_region_range.end == 0: - last_region_range.end = target_addr + payload_size - last_region_data += data - ranges[-1] = (last_region_range, last_region_data) - else: + if family_id != filesize_familyID: + raise NotImplementedError("Multiple family IDs in file not supported") + + # unpack data + if flags & Uf2Flags.NOT_MAIN_FLASH: + # data not written to main flash + raise NotImplementedError( + "Data not written to main flash is currently not supported" + ) + elif flags & Uf2Flags.FILE_CONTAINER: + # file container + raise NotImplementedError("File containers are currently not implemented") + elif flags & Uf2Flags.FAMILY_ID_PRESENT: + data = payload_data[0:payload_size] + if 
len(ranges) == 0: ranges.append((Range(target_addr, target_addr + payload_size), data)) - else: - # unsupported flags - raise ValueError(f"Unsupported flags {flags}") + else: + last_region_range, last_region_data = ranges[-1] + + # if range is adjacent, extend, otherwise start a new one + if target_addr - last_region_range.end == 0: + last_region_range.end = target_addr + payload_size + last_region_data += data + ranges[-1] = (last_region_range, last_region_data) + else: + ranges.append((Range(target_addr, target_addr + payload_size), data)) + else: + # unsupported flags + raise ValueError(f"Unsupported flags {flags}") # count vs 0 indexed (there are 256 blocks from 0-255) if file_num_blocks != (block_no + 1): @@ -208,14 +219,13 @@ async def pack(self, resource: Resource, config=None): ) ): memory_region = await memory_region_r.view_as(CodeRegion) - data = await memory_region_r.get_data() - data_length = await memory_region_r.get_data_length() - data_range = memory_region.vaddr_range() - addr = data_range.start + with await memory_region_r.get_data_memoryview() as data: + data_range = memory_region.vaddr_range() + addr = data_range.start - for i in range(0, data_length, 256): - payloads.append((addr + i, 256, data[i : (i + 256)])) - continue + for i in range(0, len(data), 256): + payloads.append((addr + i, 256, bytes(data[i : (i + 256)]))) + continue num_blocks = len(payloads) block_no = 0 @@ -223,21 +233,23 @@ async def pack(self, resource: Resource, config=None): file_attributes = resource.get_attributes(attributes_type=Uf2FileAttributes) family_id = file_attributes.family_id - repacked_data = b"" + repacked_data = bytearray() for target_addr, payload_size, payload_data in payloads: - repacked_data += struct.pack( - "8I476sI", - UF2_MAGIC_START_ONE, - UF2_MAGIC_START_TWO, - Uf2Flags.FAMILY_ID_PRESENT, - target_addr, - payload_size, - block_no, - num_blocks, - family_id, - payload_data + b"\x00" * (467 - payload_size), # add padding - UF2_MAGIC_END, + repacked_data.extend( + struct.pack( + "8I476sI", + UF2_MAGIC_START_ONE, + UF2_MAGIC_START_TWO, + Uf2Flags.FAMILY_ID_PRESENT, + target_addr, + payload_size, + block_no, + num_blocks, + family_id, + payload_data + b"\x00" * (467 - payload_size), # add padding + UF2_MAGIC_END, + ) ) block_no += 1 diff --git a/ofrak_core/ofrak/core/uimage.py b/ofrak_core/ofrak/core/uimage.py index 330bc23a7..8e2b7ba18 100644 --- a/ofrak_core/ofrak/core/uimage.py +++ b/ofrak_core/ofrak/core/uimage.py @@ -313,12 +313,12 @@ class UImageHeaderAttributesAnalyzer(Analyzer[None, UImageHeader]): outputs = (UImageHeader,) async def analyze(self, resource: Resource, config=None) -> UImageHeader: - tmp = await resource.get_data() - deserializer = BinaryDeserializer( - io.BytesIO(tmp), - endianness=Endianness.BIG_ENDIAN, - word_size=4, - ) + with await resource.get_data_memoryview() as tmp: + deserializer = BinaryDeserializer( + io.BytesIO(tmp), + endianness=Endianness.BIG_ENDIAN, + word_size=4, + ) deserialized = deserializer.unpack_multiple(f"IIIIIIIBBBB{UIMAGE_NAME_LEN}s") ( @@ -363,13 +363,13 @@ class UImageMultiHeaderAttributesAnalyzer(Analyzer[None, UImageMultiHeader]): outputs = (UImageMultiHeader,) async def analyze(self, resource: Resource, config=None) -> UImageMultiHeader: - resource_data = await resource.get_data() - deserializer = BinaryDeserializer( - io.BytesIO(resource_data), - endianness=Endianness.BIG_ENDIAN, - word_size=4, - ) - uimage_multi_header_size = (len(resource_data) - 4) // 4 # Remove trailing null dword + with await 
resource.get_data_memoryview() as resource_data: + deserializer = BinaryDeserializer( + io.BytesIO(resource_data), + endianness=Endianness.BIG_ENDIAN, + word_size=4, + ) + uimage_multi_header_size = (len(resource_data) - 4) // 4 # Remove trailing null dword deserialized = deserializer.unpack_multiple(f"{uimage_multi_header_size}I") return UImageMultiHeader(deserialized) @@ -568,7 +568,7 @@ class UImagePacker(Packer[None]): targets = (UImage,) async def pack(self, resource: Resource, config=None): - repacked_body_data = b"" + repacked_body_data = bytearray() uimage_view = await resource.view_as(UImage) header = await uimage_view.get_header() if header.get_type() == UImageType.MULTI: @@ -578,9 +578,11 @@ async def pack(self, resource: Resource, config=None): multi_header = await uimage_view.get_multi_header() multiheader_modifier_config = UImageMultiHeaderModifierConfig(image_sizes=image_sizes) await multi_header.resource.run(UImageMultiHeaderModifier, multiheader_modifier_config) - repacked_body_data += await multi_header.resource.get_data() + with await multi_header.resource.get_data_memoryview() as data: + repacked_body_data.extend(data) for uimage_body in await uimage_view.get_bodies(): - repacked_body_data += await uimage_body.resource.get_data() + with await uimage_body.resource.get_data_memoryview() as data: + repacked_body_data.extend(data) # If there are UImageTrailingBytes, get them as well. resource_children = await resource.get_children() @@ -588,7 +590,8 @@ async def pack(self, resource: Resource, config=None): trailing_bytes_r = await resource.get_only_child_as_view( UImageTrailingBytes, ResourceFilter.with_tags(UImageTrailingBytes) ) - repacked_body_data += await trailing_bytes_r.resource.get_data() + with await trailing_bytes_r.resource.get_data_memoryview() as data: + repacked_body_data.extend(data) ih_size = len(repacked_body_data) ih_dcrc = zlib.crc32(repacked_body_data) header_modifier_config = UImageHeaderModifierConfig(ih_size=ih_size, ih_dcrc=ih_dcrc) diff --git a/ofrak_core/ofrak/core/zlib.py b/ofrak_core/ofrak/core/zlib.py index 125926359..7bd7c4b73 100644 --- a/ofrak_core/ofrak/core/zlib.py +++ b/ofrak_core/ofrak/core/zlib.py @@ -27,8 +27,7 @@ class ZlibCompressionLevelAnalyzer(Analyzer[None, ZlibData]): outputs = (ZlibData,) async def analyze(self, resource: Resource, config=None) -> ZlibData: - zlib_data = await resource.get_data(Range(0, 2)) - flevel = zlib_data[-1] + (flevel,) = await resource.get_data(Range(1, 2)) if flevel == 0x01: compression_level = 1 elif flevel == 0x5E: @@ -52,8 +51,8 @@ class ZlibUnpacker(Unpacker[None]): children = (GenericBinary,) async def unpack(self, resource: Resource, config=None): - zlib_data = await resource.get_data() - zlib_uncompressed_data = zlib.decompress(zlib_data) + with await resource.get_data_memoryview() as zlib_data: + zlib_uncompressed_data = zlib.decompress(zlib_data) await resource.create_child( tags=(GenericBinary,), data=zlib_uncompressed_data, @@ -71,10 +70,9 @@ async def pack(self, resource: Resource, config=None): zlib_view = await resource.view_as(ZlibData) compression_level = zlib_view.compression_level zlib_child = await zlib_view.get_child() - zlib_data = await zlib_child.resource.get_data() - zlib_compressed = zlib.compress(zlib_data, compression_level) - - original_zlib_size = await zlib_view.resource.get_data_length() + with await zlib_child.resource.get_data_memoryview() as zlib_data: + zlib_compressed = zlib.compress(zlib_data, compression_level) + original_zlib_size = await zlib_view.resource.get_data_length()
resource.queue_patch(Range(0, original_zlib_size), zlib_compressed) diff --git a/ofrak_core/ofrak/core/zstd.py b/ofrak_core/ofrak/core/zstd.py index 91df1c1b9..a54d9ad83 100644 --- a/ofrak_core/ofrak/core/zstd.py +++ b/ofrak_core/ofrak/core/zstd.py @@ -45,7 +45,8 @@ async def unpack(self, resource: Resource, config: ComponentConfig = None) -> No proc = await asyncio.create_subprocess_exec( *cmd, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE ) - result, _ = await proc.communicate(await resource.get_data()) + with await resource.get_data_memoryview() as data: + result, _ = await proc.communicate(data) if proc.returncode: raise CalledProcessError(returncode=proc.returncode, cmd=cmd) @@ -65,7 +66,6 @@ async def pack(self, resource: Resource, config: Optional[ZstdPackerConfig] = No config = ZstdPackerConfig(compression_level=19) zstd_view = await resource.view_as(ZstdData) child_file = await zstd_view.get_child() - uncompressed_data = await child_file.resource.get_data() command = ["zstd", "-T0", f"-{config.compression_level}"] if config.compression_level > 19: @@ -73,7 +73,8 @@ async def pack(self, resource: Resource, config: Optional[ZstdPackerConfig] = No proc = await asyncio.create_subprocess_exec( *command, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE ) - result, _ = await proc.communicate(uncompressed_data) + with await child_file.resource.get_data_memoryview() as data: + result, _ = await proc.communicate(data) if proc.returncode: raise CalledProcessError(returncode=proc.returncode, cmd=command) diff --git a/ofrak_core/ofrak/resource.py b/ofrak_core/ofrak/resource.py index 69c651959..1835f4122 100644 --- a/ofrak_core/ofrak/resource.py +++ b/ofrak_core/ofrak/resource.py @@ -167,25 +167,51 @@ def get_model(self) -> MutableResourceModel: """ return self._resource - async def get_data(self, range: Optional[Range] = None) -> bytes: + async def get_data_memoryview( + self, range: Optional[Range] = None, force_readonly: bool = True + ) -> memoryview: """ A resource often represents a chunk of underlying binary data. This method returns the - entire chunk by default; this can be reduced by an optional parameter. + entire chunk by default; this can be reduced by an optional parameter. This method is provided + in order to reduce the number of copy operations in certain scenarios such as passing a buffer + to ctypes. :param range: A range within the resource's data, relative to the resource's data itself (e.g. Range(0, 10) returns the first 10 bytes of the chunk) + :param force_readonly: If True (default), the returned memoryview will be guaranteed to be + readonly. Set it to False to request a writable (and therefore unsafe) memoryview where available. This should only + be used for specific optimizations such as avoiding copies when passing data via a ctypes buffer. - :return: The full range or a partial range of this resource's bytes + :return: The full range or a partial range of this resource's data as a memoryview, which should + be used as a context manager. The returned memoryview has itemsize=1 and + should NOT be written to directly, even if it is not readonly. """ if self._resource.data_id is None: raise ValueError( "Resource does not have a data_id.
Cannot get data from a resource with no data" ) - data = await self._data_service.get_data(self._resource.data_id, range) + memview = await self._data_service.get_data_memoryview(self._resource.data_id, range) if range is None: - range = Range(0, len(data)) + range = Range(0, len(memview)) self._component_context.access_trackers[self._resource.id].data_accessed.add(range) - return data + if force_readonly: + memview_ro = memview.toreadonly() + memview.release() + return memview_ro + return memview + + async def get_data(self, range: Optional[Range] = None) -> bytes: + """ + A resource often represents a chunk of underlying binary data. This method returns the + entire chunk by default; this can be reduced by an optional parameter. + + :param range: A range within the resource's data, relative to the resource's data itself + (e.g. Range(0, 10) returns the first 10 bytes of the chunk) + + :return: The full range or a partial range of this resource's bytes + """ + with await self.get_data_memoryview(range) as buffer: + return buffer.tobytes() async def get_data_length(self) -> int: """ @@ -579,7 +605,8 @@ async def write_to(self, destination: BinaryIO, pack: bool = True): if pack is True: await self.pack_recursively() - destination.write(await self.get_data()) + with await self.get_data_memoryview() as buffer: + destination.write(buffer) async def _analyze_attributes(self, attribute_types: Tuple[Type[ResourceAttributes], ...]): job_context = self._job_context @@ -1442,14 +1469,9 @@ async def flush_data_to_disk(self, path: str, pack: bool = True): if pack is True: await self.pack_recursively() - data = await self.get_data() - if data is not None: + with await self.get_data_memoryview() as buffer: with open(path, "wb") as f: - f.write(data) - else: - # Create empty file - with open(path, "wb") as f: - pass + f.write(buffer) def __repr__(self): properties = [ @@ -1539,9 +1561,11 @@ async def temp_to_disk( delete: bool = True, ) -> AsyncIterator[str]: with tempfile.NamedTemporaryFile( - mode="wb", prefix=prefix, suffix=suffix, dir=dir, delete_on_close=False, delete=delete + mode="w+b", prefix=prefix, suffix=suffix, dir=dir, delete_on_close=False, delete=delete ) as temp: - temp.write(await self.get_data()) + # This cast() shouldn't actually be needed but mypy doesn't correctly + # infer the type of temp + await self.write_to(cast(BinaryIO, temp), pack=False) temp.close() yield temp.name @@ -1729,24 +1753,24 @@ async def _default_summarize_resource(resource: Resource) -> str: if resource._resource.data_id: root_data_range = await resource.get_data_range_within_root() parent_data_range = await resource.get_data_range_within_parent() - data = await resource.get_data() - if len(data) <= 128: - # Convert bytes to string to check .isprintable without doing .decode. Note that - # not all ASCII is printable, so we have to check both decodable and printable - raw_data_str = "".join(map(chr, data)) - if raw_data_str.isascii() and raw_data_str.isprintable(): - data_string = f'data_ascii="{data.decode("ascii")}"' + with await resource.get_data_memoryview() as data: + if len(data) <= 128: + # Convert bytes to string to check .isprintable without doing .decode. 
Note that + # not all ASCII is printable, so we have to check both decodable and printable + raw_data_str = "".join(map(chr, data)) + if raw_data_str.isascii() and raw_data_str.isprintable(): + data_string = f'data_ascii="{str(data, "ascii")}"' + else: + data_string = f"data_hex={data.hex()}" else: - data_string = f"data_hex={data.hex()}" - else: - sha256 = hashlib.sha256() - sha256.update(data) - data_string = f"data_hash={sha256.hexdigest()[:8]}" - data_info = ( - f", global_offset=({hex(root_data_range.start)}-{hex(root_data_range.end)})" - f", parent_offset=({hex(parent_data_range.start)}-{hex(parent_data_range.end)})" - f", {data_string}" - ) + sha256 = hashlib.sha256() + sha256.update(data) + data_string = f"data_hash={sha256.hexdigest()[:8]}" + data_info = ( + f", global_offset=({hex(root_data_range.start)}-{hex(root_data_range.end)})" + f", parent_offset=({hex(parent_data_range.start)}-{hex(parent_data_range.end)})" + f", {data_string}" + ) else: data_info = "" return ( diff --git a/ofrak_core/ofrak/service/data_service.py b/ofrak_core/ofrak/service/data_service.py index 67845493d..652241418 100644 --- a/ofrak_core/ofrak/service/data_service.py +++ b/ofrak_core/ofrak/service/data_service.py @@ -97,14 +97,28 @@ async def get_range_within_other(self, data_id: DataId, within_data_id: DataId) else: return within_model.range.intersect(model.range).translate(-within_model.range.start) - async def get_data(self, data_id: DataId, data_range: Optional[Range] = None) -> bytes: + def _get_data_memoryview( + self, data_id: DataId, data_range: Optional[Range] = None + ) -> memoryview: model = self._get_by_id(data_id) root = self._get_root_by_id(model.root_id) - if data_range is not None: - translated_range = data_range.translate(model.range.start).intersect(root.model.range) - return root.data[translated_range.start : translated_range.end] - else: - return root.data[model.range.start : model.range.end] + with memoryview(root.data) as memview: + if data_range is not None: + translated_range = data_range.translate(model.range.start).intersect( + root.model.range + ) + return memview[translated_range.start : translated_range.end] + else: + return memview[model.range.start : model.range.end] + + async def get_data_memoryview( + self, data_id: DataId, data_range: Optional[Range] = None + ) -> memoryview: + return self._get_data_memoryview(data_id, data_range) + + async def get_data(self, data_id: DataId, data_range: Optional[Range] = None) -> bytes: + with self._get_data_memoryview(data_id, data_range) as buffer: + return buffer.tobytes() async def apply_patches(self, patches: List[DataPatch]) -> List[DataPatchesResult]: patches_by_root: Dict[DataId, List[DataPatch]] = defaultdict(list) @@ -261,13 +275,11 @@ def _apply_patches_to_root( for affected_range in affected_ranges: results[root_data_id].append(affected_range) - new_root_data = bytearray(root.data) # Apply finalized patches to data and data models for patch_range, data, size_diff in finalized_ordered_patches: - new_root_data[patch_range.start : patch_range.end] = data + root.data[patch_range.start : patch_range.end] = data if size_diff != 0: root.resize_range(patch_range, size_diff) - root.data = bytes(new_root_data) return [ DataPatchesResult(data_id, results_for_id) @@ -326,7 +338,7 @@ def length(self) -> int: def __init__(self, model: DataModel, data: bytes): self.model: DataModel = model - self.data = data + self.data = bytearray(data) self._children: Dict[DataId, DataModel] = dict() # A pair of sorted 2D arrays, where each "point" in 
the grid is a set of children's data IDs diff --git a/ofrak_core/ofrak/service/data_service_i.py b/ofrak_core/ofrak/service/data_service_i.py index 45fb94b2a..2d0032fd1 100644 --- a/ofrak_core/ofrak/service/data_service_i.py +++ b/ofrak_core/ofrak/service/data_service_i.py @@ -137,6 +137,28 @@ async def get_data(self, data_id: bytes, data_range: Optional[Range] = None) -> """ raise NotImplementedError() + @abstractmethod + async def get_data_memoryview( + self, data_id: bytes, data_range: Optional[Range] = None + ) -> memoryview: + """ + Get the data (or section of data) of a model as a memoryview object. This method is provided + in order to reduce the number of copy operations in certain scenarios such as passing a buffer + to ctypes. + + :param data_id: A unique ID for a data model + :param data_range: An optional range within the model's data to return + + :return: A memoryview of the data from the model associated with `data_id` - the full data by default, a + specific slice if `data_range` is provided, and an empty memoryview if `data_range` is provided but + is outside the modeled data. The returned memoryview will have itemsize=1 and may or may not + be readonly. It should be used as a context manager so that it is released + automatically, and it should NOT be written to. + + :raises NotFoundError: if `data_id` is not associated with any known model + """ + raise NotImplementedError() + @abstractmethod async def apply_patches( self, diff --git a/ofrak_core/test_ofrak/components/test_data.py b/ofrak_core/test_ofrak/components/test_data.py index 0bfef928d..8d4d42931 100644 --- a/ofrak_core/test_ofrak/components/test_data.py +++ b/ofrak_core/test_ofrak/components/test_data.py @@ -18,6 +18,9 @@ class MockResource: async def get_data(self): return self.data + async def get_data_memoryview(self): + return memoryview(self.data) + @dataclass class MockDataWord(DataWord):
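
Usage sketch (illustrative, not part of the patch itself): the API introduced above is Resource.get_data_memoryview, which hands back a memoryview intended to be used as a context manager so the underlying buffer is released promptly. The sketch below mirrors the Sha256Analyzer change in checksum.py; the hash_resource_data helper name is hypothetical and not something this diff adds.

import hashlib

from ofrak.resource import Resource


async def hash_resource_data(resource: Resource) -> str:
    # Hypothetical helper: hash a resource's data without materializing a bytes copy.
    sha256 = hashlib.sha256()
    # Using the memoryview as a context manager releases it as soon as the block
    # exits, which is the pattern the new get_data_memoryview docstring recommends.
    with await resource.get_data_memoryview() as data:
        sha256.update(data)  # hashlib accepts any object supporting the buffer protocol
    return sha256.hexdigest()

For consumers that need a writable buffer (for example ctypes), the magic.py hunk above shows the companion pattern: request the view with force_readonly=False and fall back to from_buffer_copy when the returned memoryview is still readonly.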