From c51ed5fa5d7e5af2983841f7ec3e870806e0bb25 Mon Sep 17 00:00:00 2001 From: Thierry de Pauw Date: Sat, 2 Dec 2023 01:29:03 +0100 Subject: [PATCH] Move join(input_dir, file) from cli to filelister (#10) --- DMARCReporting/cli.py | 5 +---- DMARCReporting/filelister.py | 3 +-- tests/test_acceptance.py | 10 +++++----- tests/test_filelister.py | 17 ++++++++++++----- 4 files changed, 19 insertions(+), 16 deletions(-) diff --git a/DMARCReporting/cli.py b/DMARCReporting/cli.py index 31ca8d1..bcb627c 100644 --- a/DMARCReporting/cli.py +++ b/DMARCReporting/cli.py @@ -1,5 +1,3 @@ - -from os.path import join import io from DMARCReporting.decompressor import DecompressorFactory @@ -17,8 +15,7 @@ def execute(self, input_dir, csv_output_file=None, show_all=False): all_data = [] for file in files: - file_path = join(input_dir, file) - report = DecompressorFactory.create(file_path).decompress(file_path) + report = DecompressorFactory.create(file).decompress(file) data = parser.parse(io.BytesIO(report), include_all=show_all) all_data += [[*row, file] for row in data] diff --git a/DMARCReporting/filelister.py b/DMARCReporting/filelister.py index aad0250..835ad06 100644 --- a/DMARCReporting/filelister.py +++ b/DMARCReporting/filelister.py @@ -1,4 +1,3 @@ - from os import listdir from os.path import isfile from os.path import join @@ -8,7 +7,7 @@ class FileLister(): def list(self, input_dir): return sorted([ - f for f in listdir(input_dir) if self._is_zip_or_gz_file(input_dir, f) + join(input_dir, f) for f in listdir(input_dir) if self._is_zip_or_gz_file(input_dir, f) ]) def _is_zip_or_gz_file(self, input_dir, file_name): diff --git a/tests/test_acceptance.py b/tests/test_acceptance.py index eb10880..2261b4a 100644 --- a/tests/test_acceptance.py +++ b/tests/test_acceptance.py @@ -23,11 +23,11 @@ def test_render(gethostbyaddr_mock): expected = ( "Source IP Source Host Payload From (From:) Envelop From (MAIL FROM) DMARC DKIM Align DKIM Auth SPF Align SPF Auth Count File\n" # noqa E501 - "------------- ------------------------------ ---------------------- -------------------------- ------- ------------ ----------- ----------- ---------- ------- --------------\n" # noqa E501 - "80.96.161.193 Unknown host bellous.com bellous.com none pass pass fail fail 3 report.xml.gz\n" # noqa E501 - "208.90.221.45 208-90-221-45.static.flhsi.com bellous.com calendar.yambo.com none pass pass fail pass 32 report.xml.gz\n" # noqa E501 - "80.96.161.193 Unknown host disicious.com disicious.com none pass pass fail fail 3 report.xml.zip\n" # noqa E501 - "208.90.221.45 208-90-221-45.static.flhsi.com disicious.com calendar.trumbee.com none pass pass fail pass 32 report.xml.zip\n" # noqa E501 + "------------- ------------------------------ ---------------------- -------------------------- ------- ------------ ----------- ----------- ---------- ------- ------------------------\n" # noqa E501 + "80.96.161.193 Unknown host bellous.com bellous.com none pass pass fail fail 3 ./samples/report.xml.gz\n" # noqa E501 + "208.90.221.45 208-90-221-45.static.flhsi.com bellous.com calendar.yambo.com none pass pass fail pass 32 ./samples/report.xml.gz\n" # noqa E501 + "80.96.161.193 Unknown host disicious.com disicious.com none pass pass fail fail 3 ./samples/report.xml.zip\n" # noqa E501 + "208.90.221.45 208-90-221-45.static.flhsi.com disicious.com calendar.trumbee.com none pass pass fail pass 32 ./samples/report.xml.zip\n" # noqa E501 ) actual = output.getvalue() diff --git a/tests/test_filelister.py b/tests/test_filelister.py index 06c28b6..589acdc 100644 --- a/tests/test_filelister.py +++ b/tests/test_filelister.py @@ -1,9 +1,11 @@ - import unittest -from .context import DMARCReporting # noqa F401 -from DMARCReporting.filelister import FileLister import tempfile import os +from os.path import join + +from .context import DMARCReporting # noqa F401 +from DMARCReporting.filelister import FileLister + from parameterized import parameterized @@ -20,9 +22,14 @@ class TestFileLister(unittest.TestCase): def test_file_lister(self, name, filesList, expected): def listerFunction(dirName): return FileLister().list(dirName) - actual = self.list_files_with_function(filesList, listerFunction) + with tempfile.TemporaryDirectory() as testDir: + [self.create_test_file(testDir, fileName) for fileName in filesList] + actual = listerFunction(testDir) + # actual = self.list_files_with_function(filesList, listerFunction) + + expected = [join(testDir, f) for f in expected] - self.assertListEqual(expected, actual) + self.assertListEqual(expected, actual) def create_test_file(self, path, filename): with open(os.path.join(path, filename), 'w'):