Skip to content

Commit

Permalink
Merge pull request #22 from secondlife/signal/detect-type
Browse files Browse the repository at this point in the history
Add support for detecting archives by signature
  • Loading branch information
bennettgoble authored Mar 24, 2023
2 parents 64aefb4 + fd88cf7 commit 9929575
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 42 deletions.
87 changes: 87 additions & 0 deletions autobuild/archive_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import multiprocessing
import tarfile
import zipfile

class ArchiveType:
GZ = "gz"
BZ2 = "bz2"
ZIP = "zip"
ZST = "zst"


# File signatures used for sniffing archive type
# https://www.garykessler.net/library/file_sigs.html
_ARCHIVE_MAGIC_NUMBERS = {
b"\x1f\x8b\x08": ArchiveType.GZ,
b"\x42\x5a\x68": ArchiveType.BZ2,
b"\x50\x4b\x03\x04": ArchiveType.ZIP,
b"\x28\xb5\x2f\xfd": ArchiveType.ZST,
}

_ARCHIVE_MAGIC_NUMBERS_MAX = max(len(x) for x in _ARCHIVE_MAGIC_NUMBERS)


def _archive_type_from_signature(filename: str):
"""Sniff archive type using file signature"""
with open(filename, "rb") as f:
head = f.read(_ARCHIVE_MAGIC_NUMBERS_MAX)
for magic, f_type in _ARCHIVE_MAGIC_NUMBERS.items():
if head.startswith(magic):
return f_type
return None


def _archive_type_from_extension(filename: str):
if filename.endswith(".tar.gz"):
return ArchiveType.GZ
if filename.endswith(".tar.bz2"):
return ArchiveType.BZ2
if filename.endswith(".tar.zst"):
return ArchiveType.ZST
if filename.endswith(".zip"):
return ArchiveType.ZIP
return None


def detect_archive_type(filename: str):
"""Given a filename, detect its ArchiveType using file extension and signature."""
f_type = _archive_type_from_extension(filename)
if f_type:
return f_type
return _archive_type_from_signature(filename)


def open_archive(filename: str) -> tarfile.TarFile | zipfile.ZipFile:
f_type = detect_archive_type(filename)

if f_type == ArchiveType.ZST:
return ZstdTarFile(filename, "r")

if f_type == ArchiveType.ZIP:
return zipfile.ZipFile(filename, "r")

return tarfile.open(filename, "r")


class ZstdTarFile(tarfile.TarFile):
def __init__(self, name, mode='r', *, level=4, zstd_dict=None, **kwargs):
from pyzstd import CParameter, ZstdFile
zstdoption = None
if mode != 'r' and mode != 'rb':
zstdoption = {CParameter.compressionLevel : level,
CParameter.nbWorkers : multiprocessing.cpu_count(),
CParameter.checksumFlag : 1}
self.zstd_file = ZstdFile(name, mode,
level_or_option=zstdoption,
zstd_dict=zstd_dict)
try:
super().__init__(fileobj=self.zstd_file, mode=mode, **kwargs)
except:
self.zstd_file.close()
raise

def close(self):
try:
super().close()
finally:
self.zstd_file.close()
16 changes: 3 additions & 13 deletions autobuild/autobuild_tool_install.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,14 @@
import os
import pprint
import sys
import tarfile
import urllib.error
import urllib.parse
import urllib.request
import zipfile

from autobuild import autobuild_base, common, configfile
from autobuild.autobuild_tool_source_environment import get_enriched_environment
from autobuild.hash_algorithms import verify_hash
from autobuild import archive_utils

logger = logging.getLogger('autobuild.install')

Expand Down Expand Up @@ -408,7 +407,7 @@ def _install_binary(configured_name, platform, package, config_file, install_dir

def get_metadata_from_package(package_file) -> configfile.MetadataDescription:
try:
with open_archive(package_file) as archive:
with archive_utils.open_archive(package_file) as archive:
f = archive.extractfile(configfile.PACKAGE_METADATA_FILE)
return configfile.MetadataDescription(stream=f)
except (FileNotFoundError, KeyError):
Expand Down Expand Up @@ -442,15 +441,6 @@ def _default_metadata_for_package(package_file: str, package = None):
return metadata


def open_archive(filename: str) -> tarfile.TarFile | zipfile.ZipFile:
if filename.endswith(".tar.zst"):
return common.ZstdTarFile(filename, "r")
elif filename.endswith(".zip"):
return zipfile.ZipFile(filename, "r")
else:
return tarfile.open(filename, "r")


class ExtractPackageResults:
files: list[str]
conflicts: list[str]
Expand All @@ -468,7 +458,7 @@ def raise_conflicts(self):


def extract_package(package_file: str, install_dir: str, dry_run: bool = False) -> ExtractPackageResults:
with open_archive(package_file) as archive:
with archive_utils.open_archive(package_file) as archive:
results = ExtractPackageResults()
for t in archive:
if t.name == configfile.PACKAGE_METADATA_FILE:
Expand Down
5 changes: 2 additions & 3 deletions autobuild/autobuild_tool_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

import getpass
import glob
import hashlib
import json
import logging
import os
Expand All @@ -32,7 +31,7 @@
from collections import UserDict
from zipfile import ZIP_DEFLATED, ZipFile

from autobuild import autobuild_base, common, configfile
from autobuild import autobuild_base, common, configfile, archive_utils
from autobuild.common import AutobuildError

logger = logging.getLogger('autobuild.package')
Expand Down Expand Up @@ -306,7 +305,7 @@ def _create_tarfile(tarfilename, format, build_directory, filelist, results: dic
tfile = tarfile.open(tarfilename, 'w:gz')
elif format == 'tzst':
tarfilename = tarfilename + '.tar.zst'
tfile = common.ZstdTarFile(tarfilename, 'w', level=22)
tfile = archive_utils.ZstdTarFile(tarfilename, 'w', level=22)
else:
raise PackageError("unknown tar archive format: %s" % format)

Expand Down
24 changes: 0 additions & 24 deletions autobuild/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,27 +524,3 @@ def has_cmd(name, subcmd: str = "help") -> bool:
except OSError:
return False
return not p.returncode


class ZstdTarFile(tarfile.TarFile):
def __init__(self, name, mode='r', *, level=4, zstd_dict=None, **kwargs):
from pyzstd import CParameter, ZstdFile
zstdoption = None
if mode != 'r' and mode != 'rb':
zstdoption = {CParameter.compressionLevel : level,
CParameter.nbWorkers : multiprocessing.cpu_count(),
CParameter.checksumFlag : 1}
self.zstd_file = ZstdFile(name, mode,
level_or_option=zstdoption,
zstd_dict=zstd_dict)
try:
super().__init__(fileobj=self.zstd_file, mode=mode, **kwargs)
except:
self.zstd_file.close()
raise

def close(self):
try:
super().close()
finally:
self.zstd_file.close()
Binary file added tests/data/archive.tar.bz2
Binary file not shown.
Binary file added tests/data/archive.tar.gz
Binary file not shown.
Binary file added tests/data/archive.tar.zst
Binary file not shown.
Binary file added tests/data/archive.zip
Binary file not shown.
32 changes: 32 additions & 0 deletions tests/test_filetype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import shutil
from os import path
from pathlib import Path
from tests.basetest import temp_dir

import pytest
from autobuild import archive_utils


_DATA_DIR = Path(__file__).parent / "data"

_ARCHIVE_TEST_CASES = (
(path.join(_DATA_DIR, "archive.tar.bz2"), archive_utils.ArchiveType.BZ2),
(path.join(_DATA_DIR, "archive.tar.gz"), archive_utils.ArchiveType.GZ),
(path.join(_DATA_DIR, "archive.tar.zst"), archive_utils.ArchiveType.ZST),
(path.join(_DATA_DIR, "archive.zip"), archive_utils.ArchiveType.ZIP),
)


@pytest.mark.parametrize("filename,expected_type", _ARCHIVE_TEST_CASES)
def test_detect_from_extension(filename, expected_type):
f_type = archive_utils.detect_archive_type(filename)
assert f_type == expected_type


@pytest.mark.parametrize("filename,expected_type", _ARCHIVE_TEST_CASES)
def test_detect_from_signature(filename, expected_type):
with temp_dir() as dir:
filename_no_ext = str(Path(dir) / "archive")
shutil.copyfile(filename, filename_no_ext)
f_type = archive_utils.detect_archive_type(filename_no_ext)
assert f_type == expected_type
4 changes: 2 additions & 2 deletions tests/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from zipfile import ZipFile

import autobuild.autobuild_tool_package as package
from autobuild import common, configfile
from autobuild import common, configfile, archive_utils
from tests.basetest import BaseTest, CaptureStdout, ExpectError, clean_dir, clean_file

# ****************************************************************************
Expand Down Expand Up @@ -76,7 +76,7 @@ def tearDown(self):

def tar_has_expected(self,tar):
if 'tar.zst' in tar:
tarball = common.ZstdTarFile(tar, 'r')
tarball = archive_utils.ZstdTarFile(tar, 'r')
else:
tarball = tarfile.open(tar, 'r')
packaged_files=tarball.getnames()
Expand Down

0 comments on commit 9929575

Please sign in to comment.