Skip to content

Commit

Permalink
Add core.invertPropagator and core.invertStaggeredPropagator.
Browse files Browse the repository at this point in the history
  • Loading branch information
SaltyChiang committed Jul 24, 2024
1 parent a8ab3a1 commit 27b2910
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 20 deletions.
67 changes: 47 additions & 20 deletions pyquda/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,46 +22,73 @@


def invert(
dslash: Dirac,
dirac: Dirac,
source_type: Literal["point", "wall", "volume", "momentum", "colorvector"],
t_srce: Union[int, List[int]],
t_srce: Union[List[int], int, None],
source_phase=None,
restart: int = 0,
):
latt_info = dslash.latt_info
latt_info = dirac.latt_info

prop = LatticePropagator(latt_info)
propag = LatticePropagator(latt_info)
for spin in range(Ns):
for color in range(Nc):
b = source(latt_info, source_type, t_srce, spin, color, source_phase)
x = dslash.invert(b)
for _ in range(restart):
r = b - dslash.mat(x)
x += dslash.invert(r)
prop.setFermion(x, spin, color)
x = dirac.invertRestart(b, restart)
propag.setFermion(x, spin, color)

return prop
return propag


def invertPropagator(
dirac: Dirac,
source_propag: Union[None, LatticePropagator] = None,
restart: int = 0,
):
latt_info = dirac.latt_info

propag = LatticePropagator(latt_info)
for spin in range(Ns):
for color in range(Nc):
b = source_propag.getFermion(spin, color)
x = dirac.invertRestart(b, restart)
propag.setFermion(x, spin, color)

return propag


def invertStaggered(
dslash: StaggeredDirac,
dirac: StaggeredDirac,
source_type: Literal["point", "wall", "volume", "momentum", "colorvector"],
t_srce: Union[int, List[int]],
t_srce: Union[List[int], int, None],
source_phase=None,
restart: int = 0,
):
latt_info = dslash.latt_info
latt_info = dirac.latt_info

prop = LatticeStaggeredPropagator(latt_info)
propag = LatticeStaggeredPropagator(latt_info)
for color in range(Nc):
b = source(latt_info, source_type, t_srce, None, color, source_phase)
x = dslash.invert(b)
for _ in range(restart):
r = b - dslash.mat(x)
x += dslash.invert(r)
prop.setFermion(x, color)
x = dirac.invertRestart(b, restart)
propag.setFermion(x, color)

return propag


def invertStaggeredPropagator(
dirac: StaggeredDirac,
source_propag: LatticeStaggeredPropagator,
restart: int = 0,
):
latt_info = dirac.latt_info

propag = LatticeStaggeredPropagator(latt_info)
for color in range(Nc):
b = source_propag.getFermion(color)
x = dirac.invertRestart(b, restart)
propag.setFermion(x, color)

return prop
return propag


def gatherLattice(data: numpy.ndarray, axes: List[int], reduce_op: Literal["sum", "mean"] = "sum", root: int = 0):
Expand Down
14 changes: 14 additions & 0 deletions pyquda/dirac/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ def invert(self, b: LatticeFermion):
self.performance()
return x

def invertRestart(self, b: LatticeFermion, restart: int):
x = self.invert(b)
for _ in range(restart):
r = b - self.mat(x)
x += self.invert(r)
return x

def mat(self, x: LatticeFermion):
b = LatticeFermion(x.latt_info)
MatQuda(b.data_ptr, x.data_ptr, self.invert_param)
Expand Down Expand Up @@ -220,6 +227,13 @@ def invert(self, b: LatticeStaggeredFermion):
self.performance()
return x

def invertRestart(self, b: LatticeStaggeredFermion, restart: int):
x = self.invert(b)
for _ in range(restart):
r = b - self.mat(x)
x += self.invert(r)
return x

def mat(self, x: LatticeStaggeredFermion):
b = LatticeStaggeredFermion(x.latt_info)
MatQuda(b.data_ptr, x.data_ptr, self.invert_param)
Expand Down

0 comments on commit 27b2910

Please sign in to comment.