Skip to content

Commit a8fccc3

Browse files
committed
Rename cb2 to evenodd.
Bug fix for GPT interface.
1 parent 230c2cf commit a8fccc3

File tree

12 files changed

+91
-73
lines changed

12 files changed

+91
-73
lines changed

pyquda_core/pyquda/__init__.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from mpi4py import MPI
77

88
from ._version import __version__ # noqa: F401
9-
from . import pyquda as quda
109
from .field import LatticeInfo
1110

1211

@@ -141,7 +140,7 @@ def _getDefaultGrid(mpi_size: int, latt_size: List[int]):
141140
return min(min_grid)
142141

143142

144-
def _initEnviron(**kwargs):
143+
def _setEnviron(**kwargs):
145144
def _setEnviron(env, key, value):
146145
if value is not None:
147146
if env in environ:
@@ -154,7 +153,7 @@ def _setEnviron(env, key, value):
154153
_setEnviron(f"QUDA_{key.upper()}", key, kwargs[key])
155154

156155

157-
def _initEnvironWarn(**kwargs):
156+
def _setEnvironWarn(**kwargs):
158157
def _setEnviron(env, key, value):
159158
if value is not None:
160159
if env in environ:
@@ -172,9 +171,8 @@ def _setEnviron(env, key, value):
172171

173172
def initGPU(backend: Literal["numpy", "cupy", "torch"] = None, gpuid: int = -1):
174173
global _CUDA_BACKEND, _HIP, _GPUID, _COMPUTE_CAPABILITY
175-
176174
if isGridInitialized():
177-
_MPI_LOGGER.critical("initGPU should be called before init", RuntimeError)
175+
_MPI_LOGGER.critical("initGPU should be called before initGrid", RuntimeError)
178176
if _GPUID < 0:
179177
from platform import node as gethostname
180178

@@ -239,17 +237,31 @@ def initGPU(backend: Literal["numpy", "cupy", "torch"] = None, gpuid: int = -1):
239237
_MPI_LOGGER.warning("GPU is already initialized", RuntimeWarning)
240238

241239

242-
def initQUDA(grid_size: List[int], gpuid: int):
240+
def initGrid(grid_size: List[int]):
241+
global _GRID_SIZE, _GRID_COORD
242+
if _GRID_SIZE is None:
243+
Gx, Gy, Gz, Gt = grid_size
244+
if _MPI_SIZE != Gx * Gy * Gz * Gt:
245+
_MPI_LOGGER.critical(f"The MPI size {_MPI_SIZE} does not match the grid size {grid_size}", ValueError)
246+
_GRID_SIZE = [Gx, Gy, Gz, Gt]
247+
_GRID_COORD = getCoordFromRank(_MPI_RANK, _GRID_SIZE)
248+
_MPI_LOGGER.info(f"Using the grid size {_GRID_SIZE}")
249+
else:
250+
_MPI_LOGGER.warning("Grid is already initialized", RuntimeWarning)
251+
252+
253+
def initQUDA(grid_size: List[int], gpuid: int, use_quda_allocator: bool = False):
243254
import atexit
255+
from . import pyquda as quda, malloc_pyquda
244256

245-
# if _CUDA_BACKEND == "cupy":
246-
# import cupy
247-
# from . import malloc_pyquda
257+
if use_quda_allocator:
258+
if _CUDA_BACKEND == "cupy":
259+
import cupy
248260

249-
# allocator = cupy.cuda.PythonFunctionAllocator(
250-
# malloc_pyquda.pyquda_device_malloc, malloc_pyquda.pyquda_device_free
251-
# )
252-
# cupy.cuda.set_allocator(allocator.malloc)
261+
allocator = cupy.cuda.PythonFunctionAllocator(
262+
malloc_pyquda.pyquda_device_malloc, malloc_pyquda.pyquda_device_free
263+
)
264+
cupy.cuda.set_allocator(allocator.malloc)
253265

254266
quda.initCommsGridQuda(4, grid_size)
255267
quda.initQuda(gpuid)
@@ -293,28 +305,23 @@ def init(
293305
"""
294306
Initialize MPI along with the QUDA library.
295307
"""
296-
global _GRID_SIZE, _GRID_COORD, _DEFAULT_LATTICE
308+
global _DEFAULT_LATTICE
297309
if _GRID_SIZE is None:
298310
initGPU(backend)
299311

300312
use_default_grid = grid_size is None and latt_size is not None
301313
use_default_latt = latt_size is not None and t_boundary is not None and anisotropy is not None
302314
if use_default_grid:
303315
grid_size = _getDefaultGrid(_MPI_SIZE, latt_size)
304-
Gx, Gy, Gz, Gt = grid_size if grid_size is not None else [1, 1, 1, 1]
305-
if _MPI_SIZE != Gx * Gy * Gz * Gt:
306-
_MPI_LOGGER.critical(f"The MPI size {_MPI_SIZE} does not match the grid size {grid_size}", ValueError)
307-
_GRID_SIZE = [Gx, Gy, Gz, Gt]
308-
_GRID_COORD = getCoordFromRank(_MPI_RANK, _GRID_SIZE)
309-
_MPI_LOGGER.info(f"Using the grid size {_GRID_SIZE}")
316+
initGrid(grid_size if grid_size is not None else [1, 1, 1, 1])
310317
if use_default_grid and not use_default_latt:
311318
_MPI_LOGGER.info(f"Using the lattice size {latt_size} only for getting the default grid size {_GRID_SIZE}")
312319
if use_default_latt:
313320
_DEFAULT_LATTICE = LatticeInfo(latt_size, t_boundary, anisotropy)
314321
_MPI_LOGGER.info(f"Using the default lattice LatticeInfo({latt_size}, {t_boundary}, {anisotropy})")
315322

316-
_initEnvironWarn(resource_path=resource_path if resource_path != "" else None)
317-
_initEnviron(
323+
_setEnvironWarn(resource_path=resource_path if resource_path != "" else None)
324+
_setEnviron(
318325
rank_verbosity=",".join(rank_verbosity) if rank_verbosity != [0] else None,
319326
enable_mps="1" if enable_mps else None,
320327
enable_gdr="1" if enable_gdr else None,

pyquda_core/pyquda/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.9.10"
1+
__version__ = "0.9.11"

pyquda_core/pyquda/field.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -130,40 +130,47 @@ def lexico(data: numpy.ndarray, axes: List[int], dtype=None):
130130
Npre = int(numpy.prod(shape[: axes[0]]))
131131
Nsuf = int(numpy.prod(shape[axes[-1] + 1 :]))
132132
dtype = data.dtype if dtype is None else dtype
133-
data_cb2 = data.reshape(Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf)
133+
data_evenodd = data.reshape(Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf)
134134
data_lexico = numpy.zeros((Npre, Lt, Lz, Ly, Lx, Nsuf), dtype)
135135
for t in range(Lt):
136136
for z in range(Lz):
137137
for y in range(Ly):
138138
eo = (t + z + y) % 2
139139
if eo == 0:
140-
data_lexico[:, t, z, y, 0::2] = data_cb2[:, 0, t, z, y, :]
141-
data_lexico[:, t, z, y, 1::2] = data_cb2[:, 1, t, z, y, :]
140+
data_lexico[:, t, z, y, 0::2] = data_evenodd[:, 0, t, z, y, :]
141+
data_lexico[:, t, z, y, 1::2] = data_evenodd[:, 1, t, z, y, :]
142142
else:
143-
data_lexico[:, t, z, y, 1::2] = data_cb2[:, 0, t, z, y, :]
144-
data_lexico[:, t, z, y, 0::2] = data_cb2[:, 1, t, z, y, :]
143+
data_lexico[:, t, z, y, 1::2] = data_evenodd[:, 0, t, z, y, :]
144+
data_lexico[:, t, z, y, 0::2] = data_evenodd[:, 1, t, z, y, :]
145145
return data_lexico.reshape(*shape[: axes[0]], Lt, Lz, Ly, Lx, *shape[axes[-1] + 1 :])
146146

147147

148-
def cb2(data: numpy.ndarray, axes: List[int], dtype=None):
148+
def evenodd(data: numpy.ndarray, axes: List[int], dtype=None):
149149
shape = data.shape
150150
Lt, Lz, Ly, Lx = [shape[axis] for axis in axes]
151151
Npre = int(numpy.prod(shape[: axes[0]]))
152152
Nsuf = int(numpy.prod(shape[axes[-1] + 1 :]))
153153
dtype = data.dtype if dtype is None else dtype
154154
data_lexico = data.reshape(Npre, Lt, Lz, Ly, Lx, Nsuf)
155-
data_cb2 = numpy.zeros((Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf), dtype)
155+
data_evenodd = numpy.zeros((Npre, 2, Lt, Lz, Ly, Lx // 2, Nsuf), dtype)
156156
for t in range(Lt):
157157
for z in range(Lz):
158158
for y in range(Ly):
159159
eo = (t + z + y) % 2
160160
if eo == 0:
161-
data_cb2[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 0::2]
162-
data_cb2[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 1::2]
161+
data_evenodd[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 0::2]
162+
data_evenodd[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 1::2]
163163
else:
164-
data_cb2[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 1::2]
165-
data_cb2[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 0::2]
166-
return data_cb2.reshape(*shape[: axes[0]], 2, Lt, Lz, Ly, Lx // 2, *shape[axes[-1] + 1 :])
164+
data_evenodd[:, 0, t, z, y, :] = data_lexico[:, t, z, y, 1::2]
165+
data_evenodd[:, 1, t, z, y, :] = data_lexico[:, t, z, y, 0::2]
166+
return data_evenodd.reshape(*shape[: axes[0]], 2, Lt, Lz, Ly, Lx // 2, *shape[axes[-1] + 1 :])
167+
168+
169+
def cb2(data: numpy.ndarray, axes: List[int], dtype=None):
170+
from . import getLogger
171+
172+
getLogger().warning("cb2 is deprecated, use evenodd instead", DeprecationWarning)
173+
return evenodd(data, axes, dtype)
167174

168175

169176
def checksum(latt_info: Union[LatticeInfo, GeneralInfo], data: numpy.ndarray) -> Tuple[int, int]:
@@ -675,9 +682,9 @@ def load(
675682
if Nc is not None:
676683
latt_info.Nc = Nc
677684
if not issubclass(cls, MultiField):
678-
retval = cls(latt_info, cb2(value, [0, 1, 2, 3]))
685+
retval = cls(latt_info, evenodd(value, [0, 1, 2, 3]))
679686
else:
680-
retval = cls(latt_info, len(label), numpy.asarray([cb2(data, [0, 1, 2, 3]) for data in value]))
687+
retval = cls(latt_info, len(label), numpy.asarray([evenodd(data, [0, 1, 2, 3]) for data in value]))
681688
secs = perf_counter() - s
682689
getLogger().debug(f"Loaded {filename} in {secs:.3f} secs, {gbytes / secs:.3f} GB/s")
683690
return retval

pyquda_utils/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pyquda import (
66
initGPU,
7+
initGrid,
78
initQUDA,
89
init,
910
getCoordFromRank,
@@ -15,6 +16,7 @@
1516
getGridCoord,
1617
setDefaultLattice,
1718
getDefaultLattice,
19+
getCUDABackend,
1820
getLogger,
1921
setLoggerLevel,
2022
dirac as fermion,
@@ -41,7 +43,8 @@
4143
LatticePropagator,
4244
LatticeStaggeredPropagator,
4345
lexico,
44-
cb2,
46+
evenodd,
47+
evenodd as cb2,
4548
)
4649
from pyquda.dirac.abstract import Multigrid, FermionDirac, StaggeredFermionDirac
4750

pyquda_utils/deprecated.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import List
22

3-
from pyquda import getLogger, getGridSize, quda, enum_quda
3+
from pyquda import getLogger, getGridSize, pyquda as quda, enum_quda
44
from pyquda.field import LatticeFermion, LatticeGauge, LatticeInfo, LatticePropagator, Nc, Ns
55
from pyquda.dirac.abstract import FermionDirac
66

@@ -101,11 +101,11 @@ def getDslash(
101101
latt_info = LatticeInfo([Lx, Ly, Lz, Lt], t_boundary, xi)
102102

103103
if clover_csw != 0.0:
104-
from .dirac.clover_wilson import CloverWilsonDirac
104+
from pyquda.dirac.clover_wilson import CloverWilsonDirac
105105

106106
return CloverWilsonDirac(latt_info, mass, tol, maxiter, clover_csw, clover_xi, geo_block_size)
107107
else:
108-
from .dirac.wilson import WilsonDirac
108+
from pyquda.dirac.wilson import WilsonDirac
109109

110110
return WilsonDirac(latt_info, mass, tol, maxiter, geo_block_size)
111111

@@ -131,6 +131,6 @@ def getStaggeredDslash(
131131
t_boundary = 1
132132
latt_info = LatticeInfo([Lx, Ly, Lz, Lt], t_boundary, 1.0)
133133

134-
from .dirac.hisq import HISQDirac
134+
from pyquda.dirac.hisq import HISQDirac
135135

136136
return HISQDirac(latt_info, mass, tol, maxiter, naik_epsilon, None)

pyquda_utils/gpt.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
from typing import List
22
import numpy
33

4-
from pyquda import getSublatticeSize, getGridSize
5-
from pyquda.field import cb2, LatticeGauge, LatticeInfo, LatticePropagator
4+
from .core import evenodd, getGridSize, LatticeGauge, LatticeInfo, LatticePropagator
65

76
import gpt as g
87

98

109
def LatticeInfoGPT(grid: g.grid, gen_simd_width: int):
1110
assert getGridSize() == grid.mpi
12-
sublatt_size = getSublatticeSize(grid.fdimensions, grid.mpi)
11+
GLx, GLy, GLz, GLt = grid.fdimensions
12+
Gx, Gy, Gz, Gt = grid.mpi
13+
Lx, Ly, Lz, Lt = GLx // Gx, GLy // Gy, GLz // Gz, GLt // Gt
14+
sublatt_size = [Lx, Ly, Lz, Lt]
1315
Nd = len(sublatt_size)
1416
precision = grid.precision.nbytes
1517
n_simd = gen_simd_width // (2 * precision)
@@ -32,7 +34,7 @@ def LatticeGaugeGPT(lattice: List[g.lattice], gen_simd_width: int, gauge: Lattic
3234
value = []
3335
for index in range(latt_info.Nd):
3436
value.append(
35-
cb2(
37+
evenodd(
3638
numpy.asarray(lattice[index].mview()[0])
3739
.view(f"<c{2 * gpt_prec}")
3840
.reshape(*gpt_latt[::-1], Nc, Nc, *gpt_simd[::-1])
@@ -49,7 +51,8 @@ def LatticeGaugeGPT(lattice: List[g.lattice], gen_simd_width: int, gauge: Lattic
4951
for index in range(latt_info.Nd):
5052
gpt_shape = [i for sl in zip(gpt_simd, gpt_latt) for i in sl]
5153
lattice[index].mview()[0][:] = (
52-
gauge[index].lexico()
54+
gauge[index]
55+
.lexico()
5356
.astype(f"<c{2 * gpt_prec}")
5457
.reshape(*gpt_shape, Nc, Nc)
5558
.transpose(1, 3, 5, 7, 8, 9, 0, 2, 4, 6)
@@ -65,7 +68,7 @@ def LatticePropagatorGPT(lattice: g.lattice, gen_simd_width: int, propagator: La
6568
Ns, Nc = latt_info.Ns, latt_info.Nc
6669
assert lattice.describe().startswith(f"ot_matrix_spin_color({Ns},{Nc})")
6770
if propagator is None:
68-
value = cb2(
71+
value = evenodd(
6972
numpy.asarray(lattice.mview()[0])
7073
.view(f"<c{2 * gpt_prec}")
7174
.reshape(*gpt_latt[::-1], Ns, Ns, Nc, Nc, *gpt_simd[::-1])

0 commit comments

Comments
 (0)