diff --git a/src/django_autotyping/management/commands/generate_stubs.py b/src/django_autotyping/management/commands/generate_stubs.py index 5e1b2e4..9d7a60f 100644 --- a/src/django_autotyping/management/commands/generate_stubs.py +++ b/src/django_autotyping/management/commands/generate_stubs.py @@ -5,7 +5,7 @@ from django.apps import apps from django.conf import settings -from django.core.management.base import BaseCommand, CommandParser +from django.core.management.base import BaseCommand, CommandError, CommandParser from django_autotyping._compat import Unpack from django_autotyping.app_settings import AutotypingSettings @@ -45,8 +45,17 @@ def add_arguments(self, parser: CommandParser) -> None: ) def handle(self, *args: Any, **options: Unpack[CommandOptions]) -> None: - create_local_django_stubs(options["local_stubs_dir"], stubs_settings.SOURCE_STUBS_DIR) + if not (stubs_dir := options["local_stubs_dir"]).is_dir(): + raise CommandError(f"The local stubs directory {stubs_dir} does not exist or is not a directory") + + if not (stubs_dir / "django-stubs").is_dir(): + self.stdout.write("Copying the 'django-stubs' package from site packages to the local stubs directory") + create_local_django_stubs(options["local_stubs_dir"], stubs_settings.SOURCE_STUBS_DIR) codemods = gather_codemods(options["ignore"]) django_context = DjangoStubbingContext(apps, settings) - run_codemods(codemods, django_context, stubs_settings) + results = run_codemods(codemods, django_context, stubs_settings) + for stub_file, content in results.items(): + self.stdout.write(f"Writing contents to {stub_file}") + target_file = options["local_stubs_dir"] / "django-stubs" / stub_file + target_file.write_text(content, encoding="utf-8") diff --git a/src/django_autotyping/stubbing/__init__.py b/src/django_autotyping/stubbing/__init__.py index 4016e80..3e14c4d 100644 --- a/src/django_autotyping/stubbing/__init__.py +++ b/src/django_autotyping/stubbing/__init__.py @@ -1,7 +1,10 @@ from __future__ import annotations +import os import shutil import site +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed from pathlib import Path import libcst as cst @@ -17,23 +20,50 @@ def run_codemods( codemods: list[type[StubVisitorBasedCodemod]], django_context: DjangoStubbingContext, stubs_settings: StubsGenerationSettings, -) -> None: +) -> dict[str, str]: + """Given a list of codemods, apply them to the related files. + + Returns: + A mapping between the stub file name and the new file content. + """ django_stubs_dir = stubs_settings.SOURCE_STUBS_DIR or _get_django_stubs_dir() + # From 'codemod -> set[files]' to 'file -> list[codemods]' + # (different codemods could apply to the same file(s)) + files_codemods_dct: defaultdict[str, list[type[StubsGenerationSettings]]] = defaultdict(list) for codemod in codemods: for stub_file in codemod.STUB_FILES: - context = CodemodContext( - filename=stub_file, scratch={"django_context": django_context, "stubs_settings": stubs_settings} - ) - transformer = codemod(context) - source_file = django_stubs_dir / stub_file - target_file = stubs_settings.LOCAL_STUBS_DIR / "django-stubs" / stub_file + files_codemods_dct[stub_file].append(codemod) + + with ProcessPoolExecutor(min(len(files_codemods_dct), os.cpu_count() or 1)) as executor: + futures = { + executor.submit( + _run_codemods_on_file, codemods, django_context, stubs_settings, django_stubs_dir / stub_file + ): stub_file + for stub_file, codemods in files_codemods_dct.items() + } + + return {futures[future]: future.result() for future in as_completed(futures)} + + +def _run_codemods_on_file( + codemods: list[type[StubVisitorBasedCodemod]], + django_context: DjangoStubbingContext, + stubs_settings: StubsGenerationSettings, + source_file: Path, +) -> str: + input_code = source_file.read_text(encoding="utf-8") + input_module = cst.parse_module(input_code) + + for codemod in codemods: + context = CodemodContext( + filename=source_file.name, scratch={"django_context": django_context, "stubs_settings": stubs_settings} + ) + transformer = codemod(context) - input_code = source_file.read_text(encoding="utf-8") - input_module = cst.parse_module(input_code) - output_module = transformer.transform_module(input_module) + input_module = transformer.transform_module(input_module) - target_file.write_text(output_module.code, encoding="utf-8") + return input_module.code def _get_django_stubs_dir() -> Path: @@ -49,7 +79,6 @@ def create_local_django_stubs(stubs_dir: Path, source_django_stubs: Path | None If `source_django_stubs` is not provided, the first entry in site packages will be used. """ - stubs_dir.mkdir(exist_ok=True) source_django_stubs = source_django_stubs or _get_django_stubs_dir() if not (stubs_dir / "django-stubs").is_dir(): shutil.copytree(source_django_stubs, stubs_dir / "django-stubs") diff --git a/tests/stubbing/test_stubs.py b/tests/stubbing/test_stubs.py index 4b360dc..bd33aaf 100644 --- a/tests/stubbing/test_stubs.py +++ b/tests/stubbing/test_stubs.py @@ -33,6 +33,7 @@ @pytest.fixture def local_stubs(tmp_path) -> Path: + tmp_path.mkdir(exist_ok=True) create_local_django_stubs(tmp_path) return tmp_path