Skip to content

Commit

Permalink
Added logic to NotebookImportExtractor to guess the encoding on initial UnicodeDecodeError (#216)
Browse files Browse the repository at this point in the history

* Added logic to NotebookImportExtractor to guess the encoding on initial UnicodeDecodeError

Co-authored-by: Mathieu Kniewallner <mathieu.kniewallner@gmail.com>
  • Loading branch information
fpgmaas and mkniewallner authored Nov 22, 2022
1 parent ea7ba57 commit 2823be6
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 14 deletions.
7 changes: 7 additions & 0 deletions deptry/imports/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from dataclasses import dataclass
from pathlib import Path

import chardet


@dataclass
class ImportExtractor(ABC):
Expand All @@ -25,3 +27,8 @@ def _extract_imports_from_ast(tree: ast.AST) -> set[str]:
imported_modules.add(node.module.split(".")[0])

return imported_modules

@staticmethod
def _get_file_encoding(file: Path) -> str:
    """Guess the text encoding of ``file`` from its raw bytes.

    The file is read in full in binary mode and the bytes are passed to
    ``chardet.detect``; the ``"encoding"`` entry of the detection result is
    returned unchanged.
    """
    with open(file, "rb") as binary_file:
        raw_bytes = binary_file.read()

    # NOTE(review): chardet may report no encoding for undetectable input,
    # in which case this would return None despite the ``str`` annotation;
    # confirm against callers.
    detection = chardet.detect(raw_bytes)
    return detection["encoding"]
22 changes: 16 additions & 6 deletions deptry/imports/extractors/notebook_import_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
import itertools
import json
import logging
import re
from dataclasses import dataclass
from pathlib import Path
Expand All @@ -17,17 +18,26 @@ class NotebookImportExtractor(ImportExtractor):

def extract_imports(self) -> set[str]:
    """Return the modules imported by the notebook at ``self.file``.

    The notebook is loaded, reduced to its code cells, and the import
    statements found in each cell are joined into one source string that is
    parsed with ``ast``. An empty set is returned when the notebook could
    not be read (or is empty).
    """
    notebook_content = self._read_ipynb_file(self.file)
    if not notebook_content:
        return set()

    statements_per_cell = [
        self._extract_import_statements_from_cell(code_cell)
        for code_cell in self._keep_code_cells(notebook_content)
    ]

    # Flatten the per-cell lists into a single newline-separated module body.
    flattened_statements = itertools.chain.from_iterable(statements_per_cell)
    import_source = "\n".join(flattened_statements)
    parsed_tree = ast.parse(import_source, str(self.file))

    return self._extract_imports_from_ast(parsed_tree)

@staticmethod
def _read_ipynb_file(path_to_ipynb: Path) -> dict[str, Any]:
with open(path_to_ipynb) as f:
notebook: dict[str, Any] = json.load(f)
@classmethod
def _read_ipynb_file(cls, path_to_ipynb: Path) -> dict[str, Any] | None:
    """Load a notebook file as JSON, retrying with a guessed encoding.

    The file is first opened with the default text encoding. If that raises
    ``UnicodeDecodeError``, the encoding is guessed via
    ``cls._get_file_encoding`` and the file is read again. If the second
    attempt also raises ``UnicodeDecodeError``, a warning is logged and
    ``None`` is returned.
    """
    try:
        with open(path_to_ipynb) as default_handle:
            return json.load(default_handle)
    except UnicodeDecodeError:
        try:
            guessed_encoding = cls._get_file_encoding(path_to_ipynb)
            with open(path_to_ipynb, encoding=guessed_encoding) as fallback_handle:
                # strict=False lets the JSON decoder accept control
                # characters inside strings on this fallback path.
                return json.load(fallback_handle, strict=False)
        except UnicodeDecodeError:
            logging.warning(f"Warning: File {path_to_ipynb} could not be decoded. Skipping...")
            return None

@staticmethod
Expand Down
8 changes: 0 additions & 8 deletions deptry/imports/extractors/python_import_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
import ast
import logging
from dataclasses import dataclass
from pathlib import Path

import chardet

from deptry.imports.extractors.base import ImportExtractor

Expand All @@ -27,8 +24,3 @@ def extract_imports(self) -> set[str]:
return set()

return self._extract_imports_from_ast(tree)

@staticmethod
def _get_file_encoding(file: Path) -> str:
    """Guess the text encoding of ``file`` from its raw bytes.

    The file is read in full in binary mode and the bytes are passed to
    ``chardet.detect``; the ``"encoding"`` entry of the detection result is
    returned unchanged.
    """
    with open(file, "rb") as binary_file:
        raw_bytes = binary_file.read()

    detection = chardet.detect(raw_bytes)
    return detection["encoding"]
42 changes: 42 additions & 0 deletions tests/imports/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,48 @@ def test_import_parser_file_encodings(file_content: str, encoding: str | None, t
assert get_imported_modules_from_file(Path(random_file_name)) == {"foo"}


@pytest.mark.parametrize(
    ("code_cell_content", "encoding"),
    [
        (
            ["import foo", "print('嘉大')"],
            "utf-8",
        ),
        (
            ["import foo", "print('Æ Ç')"],
            "iso-8859-15",
        ),
        (
            ["import foo", "print('嘉大')"],
            "utf-16",
        ),
        (
            ["my_string = '🐺'", "import foo"],
            None,
        ),
    ],
)
def test_import_parser_file_encodings_ipynb(code_cell_content: list[str], encoding: str | None, tmp_path: Path) -> None:
    """Imports are extracted from notebooks written in different encodings.

    Each case writes a minimal one-cell notebook containing non-ASCII text
    with the given ``encoding`` (``None`` means ``open()`` picks its default)
    and checks that only ``foo`` is reported as imported.
    """
    random_file_name = f"file_{uuid.uuid4()}.ipynb"

    # Minimal notebook JSON with a single code cell; every entry of
    # ``code_cell_content`` becomes one quoted string in the cell's "source".
    file_content = f"""{{
"cells": [
{{
"cell_type": "code",
"metadata": {{}},
"source": [
{", ".join([ f'"{code_line}"' for code_line in code_cell_content])}
]
}}
]}}"""

    with run_within_dir(tmp_path):
        with open(random_file_name, "w", encoding=encoding) as f:
            f.write(file_content)

        # The file name is relative, so the assertion must run while the
        # working directory is still ``tmp_path``.
        assert get_imported_modules_from_file(Path(random_file_name)) == {"foo"}


def test_import_parser_file_encodings_warning(tmp_path: Path, caplog: LogCaptureFixture) -> None:
with run_within_dir(tmp_path):
with open("file1.py", "w", encoding="utf-8") as f:
Expand Down

0 comments on commit 2823be6

Please sign in to comment.