Skip to content

Commit

Permalink
Merge pull request #8 from mbway/optimisations
Browse files Browse the repository at this point in the history
Optimisations
  • Loading branch information
messense authored Nov 18, 2024
2 parents dc44f36 + ae3add5 commit 49f973b
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 8 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ repos:
hooks:
- id: ruff-format
- id: ruff
args: [ --fix ]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.10.0
hooks:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ ignore = [
"RET505", # superfluous-else-return
"S101", # assert
"S301", # suspicious-pickle-usage
"S311", # suspicious-non-cryptographic-random-usage
"S324", # hashlib-insecure-hash-function
"S603", # subprocess-without-shell-equals-true
"S607", # start-process-with-partial-path
Expand Down
17 changes: 13 additions & 4 deletions src/maturin_import_hook/_resolve_project.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def load(path: Path) -> "_TomlFile":
data = tomllib.load(f)
return _TomlFile(path, data)

@staticmethod
def from_string(path: Path, data_str: str) -> "_TomlFile":
return _TomlFile(path, tomllib.loads(data_str))

def get_value_or_default(self, keys: list[str], required_type: type[_T], default: _T) -> _T:
value = self.get_value(keys, required_type)
return default if value is None else value
Expand Down Expand Up @@ -57,10 +61,12 @@ def get_value(self, keys: list[str], required_type: type[_T]) -> Optional[_T]:
def find_cargo_manifest(project_dir: Path) -> Optional[Path]:
pyproject_path = project_dir / "pyproject.toml"
if pyproject_path.exists():
pyproject = _TomlFile.load(pyproject_path)
relative_manifest_path = pyproject.get_value(["tool", "maturin", "manifest-path"], str)
if relative_manifest_path is not None:
return project_dir / relative_manifest_path
pyproject_data = pyproject_path.read_text()
if "manifest-path" in pyproject_data:
pyproject = _TomlFile.from_string(pyproject_path, pyproject_data)
relative_manifest_path = pyproject.get_value(["tool", "maturin", "manifest-path"], str)
if relative_manifest_path is not None:
return project_dir / relative_manifest_path

manifest_path = project_dir / "Cargo.toml"
if manifest_path.exists():
Expand All @@ -80,6 +86,9 @@ class ProjectResolver:
def __init__(self) -> None:
self._resolved_project_cache: dict[Path, Optional[MaturinProject]] = {}

def clear_cache(self) -> None:
self._resolved_project_cache.clear()

def resolve(self, project_dir: Path) -> Optional["MaturinProject"]:
if project_dir not in self._resolved_project_cache:
resolved = None
Expand Down
21 changes: 20 additions & 1 deletion src/maturin_import_hook/project_importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import urllib.request
from abc import ABC, abstractmethod
from collections.abc import Iterator, Sequence
from functools import lru_cache
from importlib.machinery import ExtensionFileLoader, ModuleSpec, PathFinder
from pathlib import Path
from types import ModuleType
Expand Down Expand Up @@ -102,6 +103,12 @@ def find_maturin(self) -> Path:
self._maturin_path = find_maturin((1, 5, 0), (2, 0, 0))
return self._maturin_path

def invalidate_caches(self) -> None:
"""called by importlib.invalidate_caches()"""
logger.info("clearing cache")
self._resolver.clear_cache()
_find_maturin_project_above.cache_clear()

def find_spec(
self,
fullname: str,
Expand Down Expand Up @@ -373,17 +380,29 @@ def _is_editable_installed_package(project_dir: Path, package_name: str) -> bool
return False


@lru_cache(maxsize=4096)
def _find_maturin_project_above(path: Path) -> Optional[Path]:
for search_path in itertools.chain((path,), path.parents):
if is_maybe_maturin_project(search_path):
return search_path
return None


def _find_dist_info_path(directory: Path, package_name: str) -> Optional[Path]:
try:
names = os.listdir(directory)
except FileNotFoundError:
return None
for name in names:
if name.startswith(package_name) and name.endswith(".dist-info"):
return Path(directory / name)
return None


def _load_dist_info(
path: Path, package_name: str, *, require_project_target: bool = True
) -> tuple[Optional[Path], bool]:
dist_info_path = next(path.glob(f"{package_name}-*.dist-info"), None)
dist_info_path = _find_dist_info_path(path, package_name)
if dist_info_path is None:
return None, False
try:
Expand Down
13 changes: 13 additions & 0 deletions tests/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ Connect to the debugger, eg [using vscode](https://code.visualstudio.com/docs/py
Note: set `CLEAR_WORKSPACE = False` in `common.py` if you want to prevent the temporary files generated during the test
from being cleared.

### Benchmarking

The `create_benchmark_data.py` script creates a directory with many python packages to represent a worst case scenario.
Run the script then run `venv/bin/python run.py` from the created directory.

One way of obtaining profiling information is to run:

```sh
venv/bin/python -m cProfile -o profile.prof run.py
pyprof2calltree -i profile.prof -o profile.log
kcachegrind profile.log
```

### Caching

sccache is a tool for caching build artifacts to speed up compilation. Unfortunately, it is currently useless for these
Expand Down
Empty file added tests/__init__.py
Empty file.
131 changes: 131 additions & 0 deletions tests/create_benchmark_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import argparse
import logging
import random
import string
import sys
import textwrap
from dataclasses import dataclass
from pathlib import Path

from runner import VirtualEnv

script_dir = Path(__file__).resolve().parent
repo_root = script_dir.parent

log = logging.getLogger("runner")
logging.basicConfig(format="[%(name)s] [%(levelname)s] %(message)s", level=logging.DEBUG)


@dataclass
class BenchmarkConfig:
seed: int
filename_length: int
depth: int
num_python_editable_packages: int

@staticmethod
def default() -> "BenchmarkConfig":
return BenchmarkConfig(
seed=0,
filename_length=10,
depth=10,
num_python_editable_packages=100,
)


def random_name(rng: random.Random, length: int) -> str:
return "".join(rng.choices(string.ascii_lowercase, k=length))


def random_path(rng: random.Random, root: Path, depth: int, name_length: int) -> Path:
path = root
for _ in range(depth):
path = path / random_name(rng, name_length)
return path


def create_python_package(root: Path) -> tuple[str, Path]:
root.mkdir(parents=True, exist_ok=False)
src_dir = root / "src" / root.name
src_dir.mkdir(parents=True)
(src_dir / "__init__.py").write_text(
textwrap.dedent(f"""\
def get_name():
return "{root.name}"
""")
)
(root / "pyproject.toml").write_text(
textwrap.dedent(f"""\
[project]
name = "{root.name}"
version = "0.1.0"
[tool.setuptools.packages.find]
where = ["src"]
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
""")
)
return root.name, src_dir


def create_benchmark_environment(root: Path, config: BenchmarkConfig) -> None:
rng = random.Random(config.seed)

log.info("creating benchmark environment at %s", root)
root.mkdir(parents=True, exist_ok=False)
venv = VirtualEnv.create(root / "venv", Path(sys.executable))

venv.install_editable_package(repo_root)

python_package_names = []
python_package_paths = []

packages_root = random_path(rng, root, config.depth, config.filename_length)
name, src_dir = create_python_package(packages_root)
python_package_names.append(name)
python_package_paths.append(src_dir)

for _ in range(config.num_python_editable_packages):
path = random_path(rng, packages_root, config.depth, config.filename_length)
name, src_dir = create_python_package(path)
python_package_names.append(name)
python_package_paths.append(src_dir)

python_package_paths_str = ", ".join(f'"{path.parent}"' for path in python_package_paths)
import_python_packages = "\n".join(f"import {name}" for name in python_package_names)
(root / "run.py").write_text(f"""\
import time
import logging
import sys
import maturin_import_hook
sys.path.extend([{python_package_paths_str}])
# logging.basicConfig(format='%(asctime)s %(name)s [%(levelname)s] %(message)s', level=logging.DEBUG)
# maturin_import_hook.reset_logger()
maturin_import_hook.install()
start = time.perf_counter()
{import_python_packages}
end = time.perf_counter()
print(f'took {{end - start:.6f}}s')
""")


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("root", type=Path, help="the location to write the benchmark data to")
args = parser.parse_args()

config = BenchmarkConfig.default()
create_benchmark_environment(args.root, config)


if __name__ == "__main__":
main()
20 changes: 17 additions & 3 deletions tests/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _pip_install_command(interpreter_path: Path) -> list[str]:


def _create_test_venv(python: Path, venv_dir: Path) -> VirtualEnv:
venv = VirtualEnv.new(venv_dir, python)
venv = VirtualEnv.create(venv_dir, python)
log.info("installing test requirements into virtualenv")
proc = subprocess.run(
[
Expand Down Expand Up @@ -156,13 +156,22 @@ def _create_virtual_env_command(interpreter_path: Path, venv_path: Path) -> list
return [str(interpreter_path), "-m", "venv", str(venv_path)]


def _install_into_virtual_env_command(interpreter_path: Path, package_path: Path) -> list[str]:
if shutil.which("uv") is not None:
log.info("using uv to install package as editable")
return ["uv", "pip", "install", "--python", str(interpreter_path), "--editable", str(package_path)]
else:
log.info("using pip to install package as editable")
return [str(interpreter_path), "-m", "pip", "install", "--editable", str(package_path)]


class VirtualEnv:
def __init__(self, root: Path) -> None:
self._root = root.resolve()
self._is_windows = platform.system() == "Windows"

@staticmethod
def new(root: Path, interpreter_path: Path) -> VirtualEnv:
def create(root: Path, interpreter_path: Path) -> VirtualEnv:
if root.exists():
log.info("removing virtualenv at %s", root)
shutil.rmtree(root)
Expand Down Expand Up @@ -194,6 +203,11 @@ def interpreter_path(self) -> Path:
assert interpreter.exists()
return interpreter

def install_editable_package(self, package_path: Path) -> None:
cmd = _install_into_virtual_env_command(self.interpreter_path, package_path)
proc = subprocess.run(cmd, capture_output=True, check=True)
log.debug("%s", proc.stdout.decode())

def activate(self, env: dict[str, str]) -> None:
"""set the environment as-if venv/bin/activate was run"""
path = env.get("PATH", "").split(os.pathsep)
Expand Down Expand Up @@ -254,7 +268,7 @@ def main() -> None:
parser.add_argument(
"--name",
default="Tests",
help="the name for the suite of tests this run (use to distinguish between OS/python version)",
help="the name to assign for the suite of tests this run (use to distinguish between OS/python version)",
)

parser.add_argument(
Expand Down

0 comments on commit 49f973b

Please sign in to comment.