diff --git a/deptry/imports/extractors/base.py b/deptry/imports/extractors/base.py
index 8ba0ab94..08b7db0f 100644
--- a/deptry/imports/extractors/base.py
+++ b/deptry/imports/extractors/base.py
@@ -5,6 +5,8 @@
 from dataclasses import dataclass
 from pathlib import Path
 
+import chardet
+
 
 @dataclass
 class ImportExtractor(ABC):
@@ -25,3 +27,9 @@ def _extract_imports_from_ast(tree: ast.AST) -> set[str]:
                 imported_modules.add(node.module.split(".")[0])
 
         return imported_modules
+
+    @staticmethod
+    def _get_file_encoding(file: Path) -> str:
+        """Return the encoding that chardet detects from the raw bytes of ``file``."""
+        with open(file, "rb") as f:
+            return chardet.detect(f.read())["encoding"]
diff --git a/deptry/imports/extractors/notebook_import_extractor.py b/deptry/imports/extractors/notebook_import_extractor.py
index c59fbc85..736c2200 100644
--- a/deptry/imports/extractors/notebook_import_extractor.py
+++ b/deptry/imports/extractors/notebook_import_extractor.py
@@ -3,6 +3,7 @@
 import ast
 import itertools
 import json
+import logging
 import re
 from dataclasses import dataclass
 from pathlib import Path
@@ -17,17 +18,30 @@ class NotebookImportExtractor(ImportExtractor):
 
     def extract_imports(self) -> set[str]:
         notebook = self._read_ipynb_file(self.file)
+        if not notebook:
+            return set()
+
         cells = self._keep_code_cells(notebook)
         import_statements = [self._extract_import_statements_from_cell(cell) for cell in cells]
-
         tree = ast.parse("\n".join(itertools.chain.from_iterable(import_statements)), str(self.file))
-
         return self._extract_imports_from_ast(tree)
 
-    @staticmethod
-    def _read_ipynb_file(path_to_ipynb: Path) -> dict[str, Any]:
-        with open(path_to_ipynb) as f:
-            notebook: dict[str, Any] = json.load(f)
+    @classmethod
+    def _read_ipynb_file(cls, path_to_ipynb: Path) -> dict[str, Any] | None:
+        """Load the notebook as JSON, retrying with a detected encoding if decoding fails.
+
+        Returns None, after logging a warning, when the retry cannot decode the file either.
+        """
+        try:
+            with open(path_to_ipynb) as ipynb_file:
+                notebook: dict[str, Any] = json.load(ipynb_file)
+        except UnicodeDecodeError:
+            try:
+                with open(path_to_ipynb, encoding=cls._get_file_encoding(path_to_ipynb)) as ipynb_file:
+                    notebook = json.load(ipynb_file, strict=False)
+            except UnicodeDecodeError:
+                logging.warning("Warning: File %s could not be decoded. Skipping...", path_to_ipynb)
+                return None
         return notebook
 
     @staticmethod
diff --git a/deptry/imports/extractors/python_import_extractor.py b/deptry/imports/extractors/python_import_extractor.py
index 9ac105c8..5c4dc8f5 100644
--- a/deptry/imports/extractors/python_import_extractor.py
+++ b/deptry/imports/extractors/python_import_extractor.py
@@ -3,9 +3,6 @@
 import ast
 import logging
 from dataclasses import dataclass
-from pathlib import Path
-
-import chardet
 
 from deptry.imports.extractors.base import ImportExtractor
 
@@ -27,8 +24,3 @@ def extract_imports(self) -> set[str]:
                 return set()
 
         return self._extract_imports_from_ast(tree)
-
-    @staticmethod
-    def _get_file_encoding(file: Path) -> str:
-        with open(file, "rb") as f:
-            return chardet.detect(f.read())["encoding"]
diff --git a/tests/imports/test_extract.py b/tests/imports/test_extract.py
index 0ba3216c..0779072a 100644
--- a/tests/imports/test_extract.py
+++ b/tests/imports/test_extract.py
@@ -101,6 +101,47 @@ def test_import_parser_file_encodings(file_content: str, encoding: str | None, t
         assert get_imported_modules_from_file(Path(random_file_name)) == {"foo"}
 
 
+@pytest.mark.parametrize(
+    ("code_cell_content", "encoding"),
+    [
+        (
+            ["import foo", "print('嘉大')"],
+            "utf-8",
+        ),
+        (
+            ["import foo", "print('Æ Ç')"],
+            "iso-8859-15",
+        ),
+        (
+            ["import foo", "print('嘉大')"],
+            "utf-16",
+        ),
+        (
+            ["my_string = '🐺'", "import foo"],
+            None,
+        ),
+    ],
+)
+def test_import_parser_file_encodings_ipynb(code_cell_content: list[str], encoding: str | None, tmp_path: Path) -> None:
+    random_file_name = f"file_{uuid.uuid4()}.ipynb"
+
+    with run_within_dir(tmp_path):
+        with open(random_file_name, "w", encoding=encoding) as f:
+            file_content = f"""{{
+            "cells": [
+                {{
+                    "cell_type": "code",
+                    "metadata": {{}},
+                    "source": [
+                        {", ".join(f'"{code_line}"' for code_line in code_cell_content)}
+                    ]
+                }}
+            ]}}"""
+            f.write(file_content)
+
+        assert get_imported_modules_from_file(Path(random_file_name)) == {"foo"}
+
+
 def test_import_parser_file_encodings_warning(tmp_path: Path, caplog: LogCaptureFixture) -> None:
     with run_within_dir(tmp_path):
         with open("file1.py", "w", encoding="utf-8") as f: