Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions src/django_autotyping/management/commands/generate_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from django.apps import apps
from django.conf import settings
from django.core.management.base import BaseCommand, CommandParser
from django.core.management.base import BaseCommand, CommandError, CommandParser

from django_autotyping._compat import Unpack
from django_autotyping.app_settings import AutotypingSettings
Expand Down Expand Up @@ -45,8 +45,17 @@ def add_arguments(self, parser: CommandParser) -> None:
)

def handle(self, *args: Any, **options: Unpack[CommandOptions]) -> None:
create_local_django_stubs(options["local_stubs_dir"], stubs_settings.SOURCE_STUBS_DIR)
if not (stubs_dir := options["local_stubs_dir"]).is_dir():
raise CommandError(f"The local stubs directory {stubs_dir} does not exist or is not a directory")

if not (stubs_dir / "django-stubs").is_dir():
self.stdout.write("Copying the 'django-stubs' package from site packages to the local stubs directory")
create_local_django_stubs(options["local_stubs_dir"], stubs_settings.SOURCE_STUBS_DIR)
codemods = gather_codemods(options["ignore"])

django_context = DjangoStubbingContext(apps, settings)
run_codemods(codemods, django_context, stubs_settings)
results = run_codemods(codemods, django_context, stubs_settings)
for stub_file, content in results.items():
self.stdout.write(f"Writing contents to {stub_file}")
target_file = options["local_stubs_dir"] / "django-stubs" / stub_file
target_file.write_text(content, encoding="utf-8")
53 changes: 41 additions & 12 deletions src/django_autotyping/stubbing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import os
import shutil
import site
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path

import libcst as cst
Expand All @@ -17,23 +20,50 @@ def run_codemods(
codemods: list[type[StubVisitorBasedCodemod]],
django_context: DjangoStubbingContext,
stubs_settings: StubsGenerationSettings,
) -> None:
) -> dict[str, str]:
"""Given a list of codemods, apply them to the related files.

Returns:
A mapping between the stub file name and the new file content.
"""
django_stubs_dir = stubs_settings.SOURCE_STUBS_DIR or _get_django_stubs_dir()

# From 'codemod -> set[files]' to 'file -> list[codemods]'
# (different codemods could apply to the same file(s))
files_codemods_dct: defaultdict[str, list[type[StubsGenerationSettings]]] = defaultdict(list)
for codemod in codemods:
for stub_file in codemod.STUB_FILES:
context = CodemodContext(
filename=stub_file, scratch={"django_context": django_context, "stubs_settings": stubs_settings}
)
transformer = codemod(context)
source_file = django_stubs_dir / stub_file
target_file = stubs_settings.LOCAL_STUBS_DIR / "django-stubs" / stub_file
files_codemods_dct[stub_file].append(codemod)

with ProcessPoolExecutor(min(len(files_codemods_dct), os.cpu_count() or 1)) as executor:
futures = {
executor.submit(
_run_codemods_on_file, codemods, django_context, stubs_settings, django_stubs_dir / stub_file
): stub_file
for stub_file, codemods in files_codemods_dct.items()
}

return {futures[future]: future.result() for future in as_completed(futures)}


def _run_codemods_on_file(
codemods: list[type[StubVisitorBasedCodemod]],
django_context: DjangoStubbingContext,
stubs_settings: StubsGenerationSettings,
source_file: Path,
) -> str:
input_code = source_file.read_text(encoding="utf-8")
input_module = cst.parse_module(input_code)

for codemod in codemods:
context = CodemodContext(
filename=source_file.name, scratch={"django_context": django_context, "stubs_settings": stubs_settings}
)
transformer = codemod(context)

input_code = source_file.read_text(encoding="utf-8")
input_module = cst.parse_module(input_code)
output_module = transformer.transform_module(input_module)
input_module = transformer.transform_module(input_module)

target_file.write_text(output_module.code, encoding="utf-8")
return input_module.code


def _get_django_stubs_dir() -> Path:
Expand All @@ -49,7 +79,6 @@ def create_local_django_stubs(stubs_dir: Path, source_django_stubs: Path | None

If `source_django_stubs` is not provided, the first entry in site packages will be used.
"""
stubs_dir.mkdir(exist_ok=True)
source_django_stubs = source_django_stubs or _get_django_stubs_dir()
if not (stubs_dir / "django-stubs").is_dir():
shutil.copytree(source_django_stubs, stubs_dir / "django-stubs")
Expand Down
1 change: 1 addition & 0 deletions tests/stubbing/test_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

@pytest.fixture
def local_stubs(tmp_path) -> Path:
tmp_path.mkdir(exist_ok=True)
create_local_django_stubs(tmp_path)
return tmp_path

Expand Down