Skip to content

Commit

Permalink
Add type hints to init().
Browse files Browse the repository at this point in the history
  • Loading branch information
SaltyChiang committed Mar 4, 2024
1 parent 7989048 commit 19a6632
Showing 1 changed file with 53 additions and 52 deletions.
105 changes: 53 additions & 52 deletions pyquda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _setEnviron(env, key, value):
_setEnviron(f"QUDA_{key.upper()}", key, kwargs[key])


def _initEnvironWarn(**kwargs):
def _initEnvironPath(**kwargs):
def _setEnviron(env, key, value):
if value is not None:
if env in environ:
Expand All @@ -93,31 +93,31 @@ def init(
anisotropy: float = None,
backend: Literal["cupy", "torch"] = "cupy",
*,
resource_path: str = None,
enable_tuning: str = None,
enable_tuning_shared: str = None,
tune_version_check: str = None,
tuning_rank: str = None,
profile_output_base: str = None,
enable_target_profile: str = None,
do_not_profile: str = None,
enable_trace: str = None,
enable_mps: str = None,
rank_verbosity: str = None,
enable_managed_memory: str = None,
enable_managed_prefetch: str = None,
enable_device_memory_pool: str = None,
enable_pinned_memory_pool: str = None,
enable_p2p: str = None,
enable_p2p_max_access_rank: str = None,
enable_gdr: str = None,
enable_gdr_blacklist: str = None,
enable_nvshmem: str = None,
deterministic_reduce: str = None,
allow_jit: str = None,
device_reset: str = None,
reorder_location: str = None,
enable_force_monitor: str = None,
resource_path: str = "",
rank_verbosity: List[int] = [0],
enable_mps: bool = False,
enable_gdr: bool = False,
enable_gdr_blacklist: List[int] = [],
enable_p2p: Literal[-3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] = 3,
enable_p2p_max_access_rank: int = 0x7FFFFFFF,
enable_nvshmem: bool = True,
allow_jit: bool = False,
reorder_location: Literal["GPU", "CPU"] = "GPU",
enable_tuning: bool = True,
enable_tuning_shared: bool = True,
tune_version_check: bool = True,
tuning_rank: int = 0,
profile_output_base: str = "",
enable_target_profile: List[int] = [],
do_not_profile: bool = False,
enable_trace: Literal[0, 1, 2] = 0,
enable_force_monitor: bool = False,
enable_device_memory_pool: bool = True,
enable_pinned_memory_pool: bool = True,
enable_managed_memory: bool = False,
enable_managed_prefetch: bool = False,
deterministic_reduce: bool = False,
device_reset: bool = False,
):
"""
Initialize MPI along with the QUDA library.
Expand All @@ -137,33 +137,34 @@ def init(
assert _MPI_SIZE == Gx * Gy * Gz * Gt
printRoot(f"INFO: Using gird {_GRID_SIZE}")

_initEnvironWarn(resource_path=resource_path)
_initEnvironPath(resource_path=resource_path if resource_path != "" else None)
_initEnviron(
resource_path=resource_path,
enable_tuning=enable_tuning,
enable_tuning_shared=enable_tuning_shared,
tune_version_check=tune_version_check,
tuning_rank=tuning_rank,
profile_output_base=profile_output_base,
enable_target_profile=enable_target_profile,
do_not_profile=do_not_profile,
enable_trace=enable_trace,
enable_mps=enable_mps,
rank_verbosity=rank_verbosity,
enable_managed_memory=enable_managed_memory,
enable_managed_prefetch=enable_managed_prefetch,
enable_device_memory_pool=enable_device_memory_pool,
enable_pinned_memory_pool=enable_pinned_memory_pool,
enable_p2p=enable_p2p,
enable_p2p_max_access_rank=enable_p2p_max_access_rank,
enable_gdr=enable_gdr,
enable_gdr_blacklist=enable_gdr_blacklist,
enable_nvshmem=enable_nvshmem,
deterministic_reduce=deterministic_reduce,
allow_jit=allow_jit,
device_reset=device_reset,
reorder_location=reorder_location,
enable_force_monitor=enable_force_monitor,
rank_verbosity=",".join(rank_verbosity) if rank_verbosity != [0] else None,
enable_mps="1" if enable_mps else None,
enable_gdr="1" if enable_gdr else None,
enable_gdr_blacklist=",".join(enable_gdr_blacklist) if enable_gdr_blacklist != [] else None,
enable_p2p=str(enable_p2p) if enable_p2p != 3 else None,
enable_p2p_max_access_rank=(
str(enable_p2p_max_access_rank) if enable_p2p_max_access_rank < 0x7FFFFFFF else None
),
enable_nvshmem="0" if not enable_nvshmem else None,
allow_jit="1" if allow_jit else None,
reorder_location="CPU" if reorder_location == "CPU" else None,
enable_tuning="0" if not enable_tuning else None,
enable_tuning_shared="0" if not enable_tuning_shared else None,
tune_version_check="0" if not tune_version_check else None,
tuning_rank=str(tuning_rank) if tuning_rank else None,
profile_output_base=profile_output_base if profile_output_base != "" else None,
enable_target_profile=",".join(enable_target_profile) if enable_target_profile != [] else None,
do_not_profile="1" if do_not_profile else None,
enable_trace="1" if enable_trace else None,
enable_force_monitor="1" if enable_force_monitor else None,
enable_device_memory_pool="0" if not enable_device_memory_pool else None,
enable_pinned_memory_pool="0" if not enable_pinned_memory_pool else None,
enable_managed_memory="1" if enable_managed_memory else None,
enable_managed_prefetch="1" if enable_managed_prefetch else None,
deterministic_reduce="1" if deterministic_reduce else None,
device_reset="1" if device_reset else None,
)

global _DEFAULT_LATTICE, _CUDA_BACKEND, _GPUID, _COMPUTE_CAPABILITY
Expand Down

0 comments on commit 19a6632

Please sign in to comment.