Skip to content

Commit

Permalink
refactor: add kwargs and nitpicky stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
tlambert03 committed Nov 7, 2024
1 parent 3e98088 commit 205f4d6
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 116 deletions.
39 changes: 25 additions & 14 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,31 +1,21 @@
[build-system]
requires = ["setuptools", "setuptools-scm"]
build-backend = "setuptools.build_meta"
requires = ["setuptools", "wheel"]

[project]
name = "witty"
description = "Well-in-Time Compiler for Cython Modules"
readme = "README.md"
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
]
classifiers = ["Programming Language :: Python :: 3"]
keywords = []
license = { text = "BSD 3-Clause License" }
authors = [
{ email = "funkej@janelia.hhmi.org", name = "Jan Funke" },
]
authors = [{ email = "funkej@janelia.hhmi.org", name = "Jan Funke" }]
dynamic = ["version"]
dependencies = ["cython", "setuptools; python_version >= '3.12'"]

[project.optional-dependencies]
dev = [
'pytest',
'ruff',
'mypy',
'pdoc',
'pre-commit'
]
dev = ['pytest', 'ruff', 'mypy', 'pdoc', 'pre-commit']

[project.urls]
homepage = "https://github.com/funkelab/witty"
Expand All @@ -34,3 +24,24 @@ repository = "https://github.com/funkelab/witty"
[tool.ruff]
target-version = "py39"
src = ["src"]

[tool.ruff.lint]
select = [
"E", # style errors
"F", # flakes
"W", # warnings
"I", # isort
"UP", # pyupgrade
]

[tool.mypy]
files = "src/**/*.py"
strict = true
disallow_any_generics = false
disallow_subclassing_any = false
show_error_codes = true
pretty = true

[[tool.mypy.overrides]]
module = ["tests.*"]
disallow_untyped_defs = false
10 changes: 8 additions & 2 deletions src/witty/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from .compile_module import compile_module
from importlib.metadata import PackageNotFoundError, version

try:
__version__ = version("witty")
except PackageNotFoundError:
# package is not installed
__version__ = "unknown"

__version__ = "0.1"
from .compile_module import compile_module

__all__ = ["compile_module"]
229 changes: 129 additions & 100 deletions src/witty/compile_module.py
Original file line number Diff line number Diff line change
@@ -1,101 +1,87 @@
import os
import Cython
from __future__ import annotations

import hashlib
import importlib.util
import os
import sys
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import Cython
from Cython.Build import cythonize
from Cython.Build.Inline import to_unicode, _get_build_extension
from Cython.Build.Inline import build_ext
from Cython.Utils import get_cython_cache_dir
from pathlib import Path

if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from types import ModuleType

try:
from distutils.core import Extension
from distutils.core import Distribution, Extension # until Python 3.11
except ImportError:
from setuptools import Extension # type: ignore [no-redef]


def load_dynamic(module_name, module_lib):
spec = importlib.util.spec_from_file_location(module_name, module_lib)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return sys.modules[module_name]
from setuptools import Distribution, Extension


def compile_module(
source_pyx,
source_files=None,
include_dirs=None,
library_dirs=None,
language="c",
extra_compile_args=None,
extra_link_args=None,
name=None,
force_rebuild=False,
quiet=False,
):
source_pyx: str,
*,
source_files: Sequence[Path | str] = (),
include_dirs: Sequence[Path | str] = (".",),
library_dirs: Sequence[Path | str] = (),
language: Literal["c", "c++"] | None = None,
extra_compile_args: list[str] | None = None,
extra_link_args: list[str] | None = None,
name: str = "_witty_module",
force_rebuild: bool = False,
quiet: bool = False,
**extension_kwargs: Any,
) -> ModuleType:
"""Compile a Cython module given as a PYX source string.
The module will be stored in Cython's cache directory. Called with the same
``source_pyx``, the cached module will be returned.
Args:
source_pyx (``str``):
The PYX source code.
source_files (list of ``Path``s, optional):
Additional source files the PYX code depends on. Changes to those
files will trigger re-compilation of the module.
include_dirs (list of ``Path``s, optional):
library_dirs (list of ``Path``s, optional):
language (``str``, optional):
extra_compile_args (list of ``str``, optional):
extra_link_args (list of ``str``, optional):
Arguments to forward to the Cython extension.
name (``str``, optional):
The base-name of the module file. Defaults to ``_witty_module``.
force_rebuild (``bool``, optional):
Force a rebuild even if a module with that name/hash already
exists.
quiet (``bool``, optional):
Supress output except errors and warnings.
Returns:
The module will be stored in
[Cython's cache directory](https://cython.readthedocs.io/en/latest/src/userguide/source_files_and_compilation.html#cython-cache).
Called with the same `source_pyx`, the cached module will be returned.
Parameters
----------
source_pyx : str
The PYX source code.
source_files : list of Path, optional
Additional source files the PYX code depends on. Changes to these
files will trigger re-compilation of the module.
include_dirs : list of Path, optional
List of directories to search for C/C++ header files (in Unix
form for portability).
library_dirs : list of Path, optional
List of directories to search for C/C++ libraries at link time.
language : str, optional
Extension language (i.e., "c", "c++", "objc"). Will be detected
from the source extensions if not provided.
extra_compile_args : list of str, optional
Extra platform- and compiler-specific information to use when
compiling the source files in 'sources'. This is typically a
list of command-line arguments for platforms and compilers where
"command line" makes sense.
extra_link_args : list of str, optional
Extra platform- and compiler-specific information to use when
linking object files to create the extension (or a new static
Python interpreter). Has a similar interpretation as for 'extra_compile_args'.
name : str, optional
The base name of the module file. Defaults to "_witty_module".
force_rebuild : bool, optional
Force a rebuild even if a module with the same name/hash already exists.
quiet : bool, optional
Suppress output except for errors and warnings.
extension_kwargs : dict, optional
Additional keyword arguments passed to the distutils `Extension` constructor.
Returns
-------
ModuleType
The compiled module.
"""

if source_files is None:
source_files = []
if include_dirs is None:
include_dirs = ["."]
if library_dirs is None:
library_dirs = []
if name is None:
name = "_witty_module"

source_pyx = to_unicode(source_pyx)
sources = [source_pyx]

for source_file in source_files:
sources.append(open(source_file, "r").read())

source_hashes = [
hashlib.md5(source.encode("utf-8")).hexdigest() for source in sources
]
source_key = (source_hashes, sys.version_info, sys.executable, Cython.__version__)
module_hash = hashlib.md5(str(source_key).encode("utf-8")).hexdigest()
module_hash = _hash_sources(source_pyx, source_files)
module_name = name + "_" + module_hash

# already loaded?
Expand All @@ -107,61 +93,104 @@ def compile_module(
module_dir = Path(get_cython_cache_dir()) / "witty"
module_pyx = (module_dir / module_name).with_suffix(".pyx")
module_lib = (module_dir / module_name).with_suffix(module_ext)
module_lock = (module_dir / module_name).with_suffix(".lock")

if not quiet:
print(f"Compiling {module_name} into {module_lib}...")

module_dir.mkdir(parents=True, exist_ok=True)

# make sure the same module is not build concurrently
with open(module_lock, "w") as lock_f:
lock_file(lock_f)

with _module_locked(module_pyx):
# already compiled?
if module_lib.is_file() and not force_rebuild:
if not quiet:
print(f"Reusing already compiled module from {module_lib}")
return load_dynamic(module_name, module_lib)
return _load_dynamic(module_name, module_lib)

# create pyx file
with open(module_pyx, "w") as f:
f.write(source_pyx)
module_pyx.write_text(source_pyx)

extension = Extension(
module_name,
sources=[str(module_pyx)],
include_dirs=include_dirs,
library_dirs=library_dirs,
include_dirs=[str(x) for x in include_dirs],
library_dirs=[str(x) for x in library_dirs],
language=language,
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
**(extension_kwargs or {}),
)

build_extension.extensions = cythonize(
[extension], compiler_directives={"language_level": "3"}, quiet=quiet
[extension],
compiler_directives={"language_level": "3"},
quiet=quiet,
)
build_extension.build_temp = str(module_dir)
build_extension.build_lib = str(module_dir)
build_extension.run()

return load_dynamic(module_name, module_lib)
return _load_dynamic(module_name, module_lib)


def _load_dynamic(module_name: str, module_path: Path) -> ModuleType:
"""Dynamically load a module from a path."""
spec = importlib.util.spec_from_file_location(module_name, module_path)
if not spec or not spec.loader:
raise ImportError(f"Failed to load module {module_name} from {module_path}")
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return sys.modules[module_name]


def _hash_sources(source_pyx: str, source_files: Sequence[Path | str] = ()) -> str:
"""Generate a hash key for a `source_pyx` along with other source file paths."""
sources = [source_pyx] + [Path(source).read_text() for source in source_files]
hashes = [hashlib.md5(source.encode("utf-8")).hexdigest() for source in sources]
source_key = (hashes, sys.version_info, sys.executable, Cython.__version__)
return hashlib.md5(str(source_key).encode("utf-8")).hexdigest()


@contextmanager
def _module_locked(module_path: Path) -> Iterator[None]:
"""Temporarily lock a module file to prevent concurrent compilation."""
module_lock_file = module_path.with_suffix(".lock")
with open(module_lock_file, "w") as lock_fd:
_lock_file(lock_fd)
try:
yield
finally:
_unlock_file(lock_fd)


def _get_build_extension() -> build_ext:
# same as `cythonize` Build.Inline._get_build_extension
# vendored to avoid using a private API
dist = Distribution()
# Ensure the build respects distutils configuration by parsing
# the configuration files
config_files = dist.find_config_files()
dist.parse_config_files(config_files)
build_extension = build_ext(dist)
build_extension.finalize_options()
return build_extension


if os.name == "nt":
import msvcrt

def lock_file(file):
msvcrt.locking(file.fileno(), msvcrt.LK_LOCK, os.path.getsize(file.name))
def _lock_file(file: Any) -> None:
msvcrt.locking(file.fileno(), msvcrt.LK_LOCK, os.path.getsize(file.name)) # type: ignore

def unlock_file(file):
msvcrt.locking(file.fileno(), msvcrt.LK_UNLCK, os.path.getsize(file.name))
def _unlock_file(file: Any) -> None:
msvcrt.locking(file.fileno(), msvcrt.LK_UNLCK, os.path.getsize(file.name)) # type: ignore

else:
import fcntl

def lock_file(file):
def _lock_file(file: Any) -> None:
fcntl.lockf(file, fcntl.LOCK_EX)

def unlock_file(file):
def _unlock_file(file: Any) -> None:
fcntl.lockf(file, fcntl.LOCK_UN)
Empty file added tests/__init__.py
Empty file.

0 comments on commit 205f4d6

Please sign in to comment.