38 changes: 30 additions & 8 deletions pyzag/chunktime.py
@@ -25,12 +25,12 @@
 # pylint: disable=abstract-method

 """
-Functions and objects to help with blocked/chunked time integration.
+Functions and objects to help with blocked/chunked time integration.

-These include:
-1. Sparse matrix classes for banded systems
-2. General sparse matrix classes
-3. Specialized solver routines working with banded systems
+These include:
+1. Sparse matrix classes for banded systems
+2. General sparse matrix classes
+3. Specialized solver routines working with banded systems
 """

 import warnings
@@ -40,8 +40,18 @@
 from torch.nn.functional import pad
 import numpy as np

+from abc import ABC, abstractmethod
+

-class ChunkNewtonRaphson:
+class NonlinearSolver(ABC):
+    """Base class for nonlinear solvers working with chunked operators"""
+
+    @abstractmethod
+    def solve(self, fn, x0) -> torch.Tensor:
+        pass
+
+
+class ChunkNewtonRaphson(NonlinearSolver):
     """Solve a nonlinear system with Newton's method where the residual and Jacobian are presented as chunked operators

     Keyword Args:
@@ -203,6 +213,8 @@ def step(self, x, J, fn, R0, take_step):

         f = torch.ones_like(nR0)

+        R, nR = None, None
+
         for _ in range(self.linesearch_iter):
             x[:, final_steps] = x0 - f.unsqueeze(-1).unsqueeze(0) * dx

@@ -217,6 +229,7 @@

             f = torch.where(decreasing, f, f * self.alpha)

+        assert R is not None and nR is not None
         return x, R, J, nR


@@ -286,6 +299,10 @@ def shape(self):
         """
         return (self.sbat, self.n, self.n)

+    @abstractmethod
+    def matvec(self, v) -> torch.Tensor:
+        pass
+

 class LUFactorization(BidiagonalOperator):
     """A factorization that uses the LU decomposition of A
@@ -601,7 +618,12 @@ class BidiagonalForwardOperator(BidiagonalOperator):
     storing the nblk-1 off diagonal blocks
     """

-    def __init__(self, *args, inverse_operator=BidiagonalThomasFactorization, **kwargs):
+    def __init__(
+        self,
+        *args,
+        inverse_operator: type[BidiagonalOperator] = BidiagonalThomasFactorization,
+        **kwargs,
+    ):
         super().__init__(*args, **kwargs)
         self.inverse_operator = inverse_operator

@@ -621,7 +643,7 @@ def forward(self, v):
         """
         return self.matvec(v)

-    def matvec(self, v):
+    def matvec(self, v) -> torch.Tensor:
         """
         :math:`A \\cdot v` in an efficient manner

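Taken together, the chunktime.py changes formalize the solver interface: `ChunkNewtonRaphson` now implements an abstract `NonlinearSolver`, and `BidiagonalOperator` gains an abstract `matvec`. As a sketch of what the new base class enables, here is a hypothetical alternative solver; it is not part of this PR, and the `(fn, x0)` contract, where `fn(x)` returns the residual and Jacobian for a trial state, is an assumption inferred from how `ChunkNewtonRaphson` is used:

```python
import torch

from pyzag import chunktime


class ChunkPicard(chunktime.NonlinearSolver):
    """Hypothetical Picard (fixed-point) solver for the new interface.

    Assumes fn(x) returns (R, J); only the residual R is used here.
    """

    def __init__(self, rtol=1.0e-6, atol=1.0e-10, miter=100):
        self.rtol = rtol
        self.atol = atol
        self.miter = miter

    def solve(self, fn, x0) -> torch.Tensor:
        x = x0.clone()
        R, _ = fn(x)
        nR0 = torch.linalg.norm(R)
        for _ in range(self.miter):
            # Converged when the residual norm drops relative to the first iterate
            if torch.linalg.norm(R) <= max(self.rtol * nR0, self.atol):
                break
            x = x - R  # naive relaxation step; assumes a well-scaled residual
            R, _ = fn(x)
        return x
```

Since `ChunkNewtonRaphson` is the only concrete implementation in the PR, anything beyond the `solve(fn, x0)` signature above is an assumption.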
42 changes: 32 additions & 10 deletions pyzag/nonlinear.py
@@ -27,6 +27,7 @@
 import torch

 from pyzag import chunktime
+from abc import ABC, abstractmethod


 class NonlinearRecursiveFunction(torch.nn.Module):
@@ -124,7 +125,15 @@ def predict(self, results, k, kinc):
         return self.history[k : k + kinc]


-class ZeroPredictor:
+class Predictor(ABC):
+    """Base class for predictors"""
+
+    @abstractmethod
+    def predict(self, results, k, kinc) -> torch.Tensor:
+        pass
+
+
+class ZeroPredictor(Predictor):
     """Predict steps just using zeros"""

     def predict(self, results, k, kinc):
@@ -138,7 +147,7 @@ def predict(self, results, k, kinc):
         return torch.zeros_like(results[k : k + kinc])


-class PreviousStepsPredictor:
+class PreviousStepsPredictor(Predictor):
     """Predict by providing the values from the previous chunk of steps"""

     def predict(self, results, k, kinc):
@@ -158,7 +167,7 @@ def predict(self, results, k, kinc):
         return results[(k - kinc) : k]


-class LastStepPredictor:
+class LastStepPredictor(Predictor):
     """Predict by providing the values from the previous single step"""

     def predict(self, results, k, kinc):
@@ -175,7 +184,7 @@ def predict(self, results, k, kinc):
         return results[k - 1].unsqueeze(0).expand((kinc,) + results.shape[1:])


-class StepExtrapolatingPredictor:
+class StepExtrapolatingPredictor(Predictor):
     """Predict by extrapolating using the previous *chunks* of steps"""

     def predict(self, results, k, kinc):
@@ -196,7 +205,7 @@ def predict(self, results, k, kinc):
         return dinc.unsqueeze(0).expand((kinc,) + results.shape[1:])


-class ExtrapolatingPredictor:
+class ExtrapolatingPredictor(Predictor):
     """Predict by extrapolating the values from the previous *single* steps"""

     def predict(self, results, k, kinc):
@@ -342,10 +351,12 @@ class RecursiveNonlinearEquationSolver(torch.nn.Module):
     def __init__(
         self,
         func,
-        step_generator=StepGenerator(1),
-        predictor=ZeroPredictor(),
-        direct_solve_operator=chunktime.BidiagonalThomasFactorization,
-        nonlinear_solver=chunktime.ChunkNewtonRaphson(),
+        step_generator: StepGenerator = StepGenerator(1),
+        predictor: Predictor = ZeroPredictor(),
+        direct_solve_operator: type[
+            chunktime.BidiagonalOperator
+        ] = chunktime.BidiagonalThomasFactorization,
+        nonlinear_solver: chunktime.NonlinearSolver = chunktime.ChunkNewtonRaphson(),
         callbacks=None,
         convert_nan_gradients=True,
     ):
@@ -457,6 +468,10 @@ def rewind(self, output_grad):
         )

         # Loop backwards through time
+        if self.result is None:
+            raise ValueError("No cached result found for adjoint rewind")
+
+        adjoint = None
         for k1, k2 in self.step_generator(len(self.result)).reverse():
             # Get our block of the results
             with torch.enable_grad():
@@ -479,6 +494,11 @@
                     grad_result, adjoint, R[:1], retain=True
                 )

+            if adjoint is None:
+                raise NotImplementedError(
+                    "Currently only supports adjoint rewind for a single block"
+                )
+
             # Do the block adjoint update
             adjoint = self.block_update_adjoint(
                 J, output_grad[k1:k2].flip(0), adjoint[-1]
@@ -488,6 +508,8 @@
         with torch.enable_grad():
             grad_result = self.accumulate(grad_result, adjoint, R[1:])

+        if adjoint is None:
+            raise ValueError("No adjoint values calculated during rewind")
         if self.convert_nan_gradients:
             return tuple(torch.nan_to_num(g) for g in grad_result), torch.nan_to_num(
                 adjoint[-1]
@@ -570,7 +592,7 @@ def forward(ctx, solver, y0, n, forces, *params):
         return y

     @staticmethod
-    def backward(ctx, output_grad):
+    def backward(ctx, *output_grad):
         with torch.no_grad():
             grad_res, adj_last = ctx.solver.rewind(output_grad)
             if ctx.needs_input_grad[1]:
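The new `Predictor` base class gives the predictor family (`ZeroPredictor`, `PreviousStepsPredictor`, `LastStepPredictor`, and the extrapolating variants) a common abstract interface, and the `RecursiveNonlinearEquationSolver` constructor now states its expectations through type hints. A custom predictor then only needs to implement `predict(results, k, kinc)`. The sketch below is hypothetical and not part of the PR:

```python
import torch

from pyzag import nonlinear


class DampedLastStepPredictor(nonlinear.Predictor):
    """Hypothetical predictor: the last converged step, scaled by a damping factor."""

    def __init__(self, factor: float = 0.9):
        self.factor = factor

    def predict(self, results, k, kinc) -> torch.Tensor:
        if k == 0:
            # Nothing to extrapolate from yet; fall back to zeros
            return torch.zeros_like(results[k : k + kinc])
        # Broadcast the damped last step over the kinc steps in this chunk,
        # mirroring LastStepPredictor above
        return self.factor * results[k - 1].unsqueeze(0).expand(
            (kinc,) + results.shape[1:]
        )
```

The separate change to `backward(ctx, *output_grad)` aligns the adjoint wrapper with `torch.autograd.Function`'s contract, which passes one gradient argument per output returned by `forward`.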
10 changes: 8 additions & 2 deletions pyzag/stochastic.py
@@ -27,8 +27,13 @@
 import torch

 import pyro
+import pyro.nn.module
 from pyro.nn import PyroSample
 import pyro.distributions as dist
+import pyro.poutine.scale_messenger
+import pyro.poutine.mask_messenger
+
+from typing import cast


 class MapNormal:
@@ -179,13 +184,14 @@ def forward(self, *args, results=None, weights=None, **kwargs):
         )

         if weights is None:
-            weights = torch.ones(shape[-1], device=self.eps.device)
+            weights = torch.ones(shape[-1], device=cast(torch.Tensor, self.eps).device)

         # Rather annoying that this is necessary; it is not a no-op, as it tells pyro
         # that these are *not* batched over the number of samples
         _ = self._sample_top()

         # Same here
+        eps = None
         if self.sample_noise_outside:
             eps = self.eps

@@ -195,7 +201,7 @@
         ), pyro.poutine.scale_messenger.ScaleMessenger(
             scale=weights
         ), pyro.poutine.mask_messenger.MaskMessenger(
-            mask=self.mask
+            mask=torch.BoolTensor(self.mask)
         ):
             self._sample_bot()
             res = self.base(*args, **kwargs)
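The stochastic.py changes are mostly about typing: `self.eps` is presumably declared as optional (or assigned in a way a static checker cannot narrow), so `typing.cast`, a runtime no-op, asserts it is a tensor at the point of use, while `torch.BoolTensor(self.mask)` makes the mask's dtype explicit for `MaskMessenger`. A minimal illustration of the `cast` pattern, with hypothetical names:

```python
from typing import Optional, cast

import torch


class NoiseHolder:
    """Hypothetical class showing the cast() pattern used above."""

    def __init__(self) -> None:
        # Declared Optional, so checkers flag attribute access like self.eps.device
        self.eps: Optional[torch.Tensor] = None

    def setup(self) -> None:
        self.eps = torch.tensor(0.1)

    def noise_device(self) -> torch.device:
        # cast() does nothing at runtime; it only narrows the type for the checker
        return cast(torch.Tensor, self.eps).device
```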