Skip to content

Commit 7bf0031

Browse files
committed
Use numpy.ndarray as input of QUDA functions.
1 parent 009affe commit 7bf0031

File tree

13 files changed

+158
-346
lines changed

13 files changed

+158
-346
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[build-system]
2-
requires = ["setuptools", "wheel", "Cython"]
2+
requires = ["setuptools", "wheel", "Cython", "numpy"]

pyquda/action/one_flavor_clover.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy
22

3-
from ..pointer import Pointers, ndarrayPointer
3+
from ..pointer import Pointers
44
from ..pyquda import computeCloverForceQuda, invertMultiShiftQuda, loadCloverQuda
55
from ..enum_quda import (
66
QUDA_MAX_MULTI_SHIFT,
@@ -140,7 +140,7 @@ def force(self, dt, new_gauge: bool):
140140
nullptr,
141141
dt,
142142
xx.even_ptrs,
143-
ndarrayPointer(numpy.array(residue_inv_square_root, "<f8")),
143+
numpy.array(residue_inv_square_root, "<f8"),
144144
self.kappa2,
145145
self.ck,
146146
num_offset,

pyquda/action/symanzik_gauge.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy
22

3-
from ..pointer import Pointers, ndarrayPointer
3+
from ..pointer import Pointers
44
from ..pyquda import computeGaugeLoopTraceQuda, computeGaugeForceQuda
55
from ..field import Nc, LatticeInfo
66
from ..dirac.pure_gauge import PureGauge
@@ -125,10 +125,10 @@ def __init__(self, latt_info: LatticeInfo, beta: float, u_0: float):
125125
def action(self) -> float:
126126
traces = numpy.zeros((self.num_paths), "<c16")
127127
computeGaugeLoopTraceQuda(
128-
ndarrayPointer(traces),
129-
ndarrayPointer(self.path),
130-
ndarrayPointer(self.lengths),
131-
ndarrayPointer(self.coeffs),
128+
traces,
129+
self.path,
130+
self.lengths,
131+
self.coeffs,
132132
self.num_paths,
133133
self.max_length,
134134
1,
@@ -139,9 +139,9 @@ def force(self, dt: float):
139139
computeGaugeForceQuda(
140140
nullptr,
141141
nullptr,
142-
ndarrayPointer(self.fpath),
143-
ndarrayPointer(self.flengths),
144-
ndarrayPointer(self.fcoeffs),
142+
self.fpath,
143+
self.flengths,
144+
self.fcoeffs,
145145
self.num_fpaths,
146146
self.max_flength,
147147
dt,

pyquda/action/two_flavor_clover.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def force(self, dt, new_gauge: bool):
6464
nullptr,
6565
dt,
6666
ndarrayPointer(self.phi.even.reshape(1, -1), True),
67-
ndarrayPointer(numpy.array([1.0], "<f8")),
67+
numpy.array([1.0], "<f8"),
6868
self.kappa2,
6969
self.ck,
7070
1,

pyquda/action/wilson_gauge.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import numpy
22

3-
from ..pointer import Pointers, ndarrayPointer
3+
from ..pointer import Pointers
44
from ..pyquda import computeGaugeLoopTraceQuda, computeGaugeForceQuda
55
from ..field import Nc, LatticeInfo
66
from ..dirac.pure_gauge import PureGauge
@@ -101,10 +101,10 @@ def __init__(self, latt_info: LatticeInfo, beta: float, u_0: float):
101101
def action(self) -> float:
102102
traces = numpy.zeros((self.num_paths), "<c16")
103103
computeGaugeLoopTraceQuda(
104-
ndarrayPointer(traces),
105-
ndarrayPointer(self.path),
106-
ndarrayPointer(self.lengths),
107-
ndarrayPointer(self.coeffs),
104+
traces,
105+
self.path,
106+
self.lengths,
107+
self.coeffs,
108108
self.num_paths,
109109
self.max_length,
110110
1,
@@ -115,9 +115,9 @@ def force(self, dt: float):
115115
computeGaugeForceQuda(
116116
nullptr,
117117
nullptr,
118-
ndarrayPointer(self.fpath),
119-
ndarrayPointer(self.flengths),
120-
ndarrayPointer(self.fcoeffs),
118+
self.fpath,
119+
self.flengths,
120+
self.fcoeffs,
121121
self.num_fpaths,
122122
self.max_flength,
123123
dt,

pyquda/dirac/hisq.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy
44

5-
from ..pointer import ndarrayPointer, Pointers
5+
from ..pointer import Pointers
66
from ..pyquda import newMultigridQuda, destroyMultigridQuda, computeKSLinkQuda
77
from ..field import LatticeInfo, LatticeGauge
88
from ..enum_quda import QudaDslashType, QudaInverterType, QudaReconstructType, QudaPrecision
@@ -147,15 +147,15 @@ def computeFatLong(self, gauge: LatticeGauge):
147147
nullptr,
148148
ulink.data_ptrs,
149149
inlink.data_ptrs,
150-
ndarrayPointer(self.fat7_coeff),
150+
self.fat7_coeff,
151151
self.gauge_param,
152152
)
153153
computeKSLinkQuda(
154154
fatlink.data_ptrs,
155155
longlink.data_ptrs,
156156
nullptr,
157157
ulink.data_ptrs,
158-
ndarrayPointer(self.level2_coeff),
158+
self.level2_coeff,
159159
self.gauge_param,
160160
)
161161

pyquda/dirac/pure_gauge.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Literal
1+
from typing import Literal
22

33
import numpy
44
from numpy.typing import NDArray
@@ -142,9 +142,9 @@ def path(
142142
computeGaugePathQuda(
143143
gauge.data_ptrs,
144144
gauge.data_ptrs,
145-
ndarrayPointer(numpy.ascontiguousarray(input_path_buf)),
146-
ndarrayPointer(numpy.ascontiguousarray(path_length)),
147-
ndarrayPointer(numpy.ascontiguousarray(loop_coeff)),
145+
input_path_buf,
146+
path_length,
147+
loop_coeff,
148148
input_path_buf.shape[1],
149149
input_path_buf.shape[2],
150150
1.0,
@@ -237,11 +237,12 @@ def qcharge(self):
237237
return self.obs_param.qcharge
238238

239239
def qchargeDensity(self):
240-
# self.obs_param.qcharge_density =
241-
# self.obs_param.compute_qcharge_density = QudaBoolean.QUDA_BOOLEAN_TRUE
242-
# performGaugeSmearQuda(self.obs_param)
243-
# self.obs_param.compute_qcharge_density = QudaBoolean.QUDA_BOOLEAN_TRUE
244-
raise NotImplementedError("qchargeDensity not implemented. Confusing size of ndarray.")
240+
retval = numpy.zeros((self.latt_info.volume), "<c16")
241+
self.obs_param.qcharge_density = ndarrayPointer(retval, True)
242+
self.obs_param.compute_qcharge_density = QudaBoolean.QUDA_BOOLEAN_TRUE
243+
gaugeObservablesQuda(self.obs_param)
244+
self.obs_param.compute_qcharge_density = QudaBoolean.QUDA_BOOLEAN_TRUE
245+
return retval
245246

246247
def gauss(self, seed: int, sigma: float):
247248
gaussGaugeQuda(seed, sigma)

0 commit comments

Comments
 (0)