diff --git a/py7zr/py7zr.py b/py7zr/py7zr.py index 6c91451a..d0f6741a 100644 --- a/py7zr/py7zr.py +++ b/py7zr/py7zr.py @@ -37,6 +37,7 @@ import sys import time from multiprocessing import Process +from shutil import ReadError from threading import Thread from typing import IO, Any, BinaryIO, Collection, Dict, List, Optional, Tuple, Type, Union @@ -1210,6 +1211,8 @@ def unpack_7zarchive(archive, path, extra=None): """ Function for registering with shutil.register_unpack_format(). """ + if not is_7zfile(archive): + raise ReadError(f"{archive} is not a 7zip file.") with SevenZipFile(archive) as arc: arc.extractall(path) diff --git a/tests/conftest.py b/tests/conftest.py index 2692c0b8..23960a67 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,18 @@ # Configuration for pytest. # Thanks to Guilherme Salgado. +import shutil + import cpuinfo import pytest +from py7zr import unpack_7zarchive + + +@pytest.fixture(scope="session") +def register_shutil_unpack_format(): + shutil.register_unpack_format("7zip", [".7z"], unpack_7zarchive) + def pytest_benchmark_update_json(config, benchmarks, output_json): """Calculate compression/decompression speed and add as extra_info""" diff --git a/tests/test_extract.py b/tests/test_extract.py index 289b56fc..50523e67 100644 --- a/tests/test_extract.py +++ b/tests/test_extract.py @@ -7,12 +7,12 @@ import shutil import subprocess import sys +import tempfile from datetime import datetime import pytest import py7zr -from py7zr import unpack_7zarchive from py7zr.exceptions import CrcError, UnsupportedCompressionMethodError from py7zr.helpers import UTC @@ -306,8 +306,7 @@ def test_zerosize_mem(): @pytest.mark.api -def test_register_unpack_archive(tmp_path): - shutil.register_unpack_format("7zip", [".7z"], unpack_7zarchive) +def test_register_unpack_archive(register_shutil_unpack_format, tmp_path): shutil.unpack_archive(str(testdata_path.joinpath("test_1.7z")), str(tmp_path)) target = tmp_path.joinpath("setup.cfg") expected_mode = 33188 @@ -326,6 +325,12 @@ def test_register_unpack_archive(tmp_path): assert m.digest() == binascii.unhexlify("b0385e71d6a07eb692f5fb9798e9d33aaf87be7dfff936fd2473eab2a593d4fd") +@pytest.mark.api +def test_register_unpack_archive_error(register_shutil_unpack_format, tmp_path): + with tempfile.NamedTemporaryFile(suffix=".7z") as f, pytest.raises(shutil.ReadError): + shutil.unpack_archive(f.name, str(tmp_path)) + + @pytest.mark.files def test_skip(): archive = py7zr.SevenZipFile(testdata_path.joinpath("test_1.7z").open(mode="rb"))