diff --git a/home/import_helpers.py b/home/import_helpers.py index 7a2e80b5..3a4cf307 100644 --- a/home/import_helpers.py +++ b/home/import_helpers.py @@ -1,6 +1,6 @@ # The error messages are processed and parsed into a list of messages we return to the user import csv -from collections.abc import Generator, Iterator +from collections.abc import Iterator from datetime import datetime from io import BytesIO, StringIO from typing import Any @@ -165,21 +165,27 @@ def extract_errors(data: dict[str | int, Any] | list[str]) -> dict[str, str]: return error_message -def fix_rows( - rows: Generator[dict[str, str], None, None] | Iterator[dict[str | Any, Any]] -) -> Iterator[dict[str, str | None]]: +def fix_rows(rows: Iterator[dict[str | Any, Any]]) -> Iterator[dict[str, str | None]]: """ Fix keys for all rows by lowercasing keys and removing whitespace from keys and values """ - for row in rows: - yield fix_row(row) + for index, row in enumerate(rows): + yield fix_row(row, index) -def fix_row(row: dict[str, str]) -> dict[str, str | None]: +def fix_row(row: dict[str, str | None], index: int) -> dict[str, str | None]: """ Fix a single row by lowercasing the key and removing whitespace from the key and value """ try: + if index == 0: + keys = [_normalise_key(k) for k, v in row.items()] + if len(keys) != len(set(keys)): + raise ImportException( + "Invalid format. Please check that there are no duplicate headers.", + row_num=index, + ) + return {_normalise_key(k): _normalise_value(v) for k, v in row.items()} except AttributeError: raise ImportException( @@ -204,11 +210,11 @@ def parse_file( return enumerate(fix_rows(read_rows(file_content)), start=2) -def read_csv(file_content: bytes) -> Iterator[dict[str, str | None]]: +def read_csv(file_content: bytes) -> Iterator[dict[str, Any]]: return csv.DictReader(StringIO(file_content.decode())) -def read_xlsx(file_content: bytes) -> Generator[dict[str, Any], None, None]: +def read_xlsx(file_content: bytes) -> Iterator[dict[str, Any]]: workbook = load_workbook(BytesIO(file_content), read_only=True, data_only=True) worksheet = get_active_sheet(workbook) diff --git a/home/tests/test_assessment_import_export.py b/home/tests/test_assessment_import_export.py index 2a5fce7e..b5210833 100644 --- a/home/tests/test_assessment_import_export.py +++ b/home/tests/test_assessment_import_export.py @@ -560,7 +560,7 @@ def test_extra_columns_xlsx( csv_impexp.import_content_file("assessment_results.csv", purge=False) xlsx_impexp.import_file("extra_columns.xlsx") assert e.value.message == [ - "Invalid format. Please check that all row values " "have headers." + "Invalid format. Please check that all row values have headers." ] def test_mismatched_length_answers(self, csv_impexp: ImportExport) -> None: diff --git a/home/tests/test_content_import_export.py b/home/tests/test_content_import_export.py index d288ed39..5e038fb9 100644 --- a/home/tests/test_content_import_export.py +++ b/home/tests/test_content_import_export.py @@ -1436,6 +1436,28 @@ def test_import_ordered_sets_csv(self, csv_impexp: ImportExport) -> None: ("relationship", "in_a_relationship"), ] + def test_import_ordered_sets_duplicate_header_csv( + self, csv_impexp: ImportExport + ) -> None: + """ + Importing a CSV with duplicate headers should throw an error + """ + set_profile_field_options() + pt, _created = Locale.objects.get_or_create(language_code="pt") + HomePage.add_root(locale=pt, title="Home (pt)", slug="home-pt") + + csv_impexp.import_file("contentpage_required_fields_multi_locale.csv") + content = csv_impexp.read_bytes("ordered_content_duplicate_header.csv") + + with pytest.raises(ImportException) as e: + csv_impexp.import_ordered_sets(content) + + assert e.value.row_num == 0 + + assert e.value.message == [ + "Invalid format. Please check that there are no duplicate headers." + ] + def test_import_ordered_sets_no_profile_fields_csv( self, csv_impexp: ImportExport ) -> None: