diff --git a/pyzag/chunktime.py b/pyzag/chunktime.py
index 88edfef..3a5c89c 100644
--- a/pyzag/chunktime.py
+++ b/pyzag/chunktime.py
@@ -25,12 +25,12 @@
 # pylint: disable=abstract-method
 
 """
-  Functions and objects to help with blocked/chunked time integration.
+Functions and objects to help with blocked/chunked time integration.
 
-  These include:
-    1. Sparse matrix classes for banded systems
-    2. General sparse matrix classes
-    3. Specialized solver routines working with banded systems
+These include:
+  1. Sparse matrix classes for banded systems
+  2. General sparse matrix classes
+  3. Specialized solver routines working with banded systems
"""
 
 import warnings
@@ -40,8 +40,18 @@
 from torch.nn.functional import pad
 
 import numpy as np
+from abc import ABC, abstractmethod
 
-class ChunkNewtonRaphson:
+
+class NonlinearSolver(ABC):
+    """Base class for nonlinear solvers working with chunked operators"""
+
+    @abstractmethod
+    def solve(self, fn, x0) -> torch.Tensor:
+        pass
+
+
+class ChunkNewtonRaphson(NonlinearSolver):
     """Solve a nonlinear system with Newton's method where the residual and Jacobian are presented as chunked operators
 
     Keyword Args:
@@ -203,6 +213,8 @@ def step(self, x, J, fn, R0, take_step):
 
         f = torch.ones_like(nR0)
 
+        R, nR = None, None
+
         for _ in range(self.linesearch_iter):
             x[:, final_steps] = x0 - f.unsqueeze(-1).unsqueeze(0) * dx
 
@@ -217,6 +229,7 @@ def step(self, x, J, fn, R0, take_step):
 
             f = torch.where(decreasing, f, f * self.alpha)
 
+        assert R is not None and nR is not None
         return x, R, J, nR
 
 
@@ -286,6 +299,10 @@ def shape(self):
         """
         return (self.sbat, self.n, self.n)
 
+    @abstractmethod
+    def matvec(self, v) -> torch.Tensor:
+        pass
+
 
 class LUFactorization(BidiagonalOperator):
     """A factorization that uses the LU decomposition of A
@@ -601,7 +618,12 @@ class BidiagonalForwardOperator(BidiagonalOperator):
         storing the nblk-1 off diagonal blocks
     """
 
-    def __init__(self, *args, inverse_operator=BidiagonalThomasFactorization, **kwargs):
+    def __init__(
+        self,
+        *args,
+        inverse_operator: type[BidiagonalOperator] = BidiagonalThomasFactorization,
+        **kwargs,
+    ):
         super().__init__(*args, **kwargs)
         self.inverse_operator = inverse_operator
 
@@ -621,7 +643,7 @@ def forward(self, v):
         """
         return self.matvec(v)
 
-    def matvec(self, v):
+    def matvec(self, v) -> torch.Tensor:
         """
         :math:`A \\cdot v` in an efficient manner
 
diff --git a/pyzag/nonlinear.py b/pyzag/nonlinear.py
index e647df4..18bdc12 100644
--- a/pyzag/nonlinear.py
+++ b/pyzag/nonlinear.py
@@ -27,6 +27,7 @@
 import torch
 
 from pyzag import chunktime
+from abc import ABC, abstractmethod
 
 
 class NonlinearRecursiveFunction(torch.nn.Module):
@@ -124,7 +125,15 @@ def predict(self, results, k, kinc):
         return self.history[k : k + kinc]
 
 
-class ZeroPredictor:
+class Predictor(ABC):
+    """Base class for predictors"""
+
+    @abstractmethod
+    def predict(self, results, k, kinc) -> torch.Tensor:
+        pass
+
+
+class ZeroPredictor(Predictor):
     """Predict steps just using zeros"""
 
     def predict(self, results, k, kinc):
@@ -138,7 +147,7 @@ def predict(self, results, k, kinc):
         return torch.zeros_like(results[k : k + kinc])
 
 
-class PreviousStepsPredictor:
+class PreviousStepsPredictor(Predictor):
     """Predict by providing the values from the previous chunk of steps steps"""
 
     def predict(self, results, k, kinc):
@@ -158,7 +167,7 @@ def predict(self, results, k, kinc):
         return results[(k - kinc) : k]
 
 
-class LastStepPredictor:
+class LastStepPredictor(Predictor):
     """Predict by providing the values from the previous single step"""
 
     def predict(self, results, k, kinc):
@@ -175,7 +184,7 @@ def predict(self, results, k, kinc):
         return results[k - 1].unsqueeze(0).expand((kinc,) + results.shape[1:])
 
 
-class StepExtrapolatingPredictor:
+class StepExtrapolatingPredictor(Predictor):
     """Predict by extrapolating using the previous *chunks* of steps"""
 
     def predict(self, results, k, kinc):
@@ -196,7 +205,7 @@ def predict(self, results, k, kinc):
         return dinc.unsqueeze(0).expand((kinc,) + results.shape[1:])
 
 
-class ExtrapolatingPredictor:
+class ExtrapolatingPredictor(Predictor):
     """Predict by extrapolating the values from the previous *single* steps"""
 
     def predict(self, results, k, kinc):
@@ -342,10 +351,12 @@ class RecursiveNonlinearEquationSolver(torch.nn.Module):
     def __init__(
         self,
         func,
-        step_generator=StepGenerator(1),
-        predictor=ZeroPredictor(),
-        direct_solve_operator=chunktime.BidiagonalThomasFactorization,
-        nonlinear_solver=chunktime.ChunkNewtonRaphson(),
+        step_generator: StepGenerator = StepGenerator(1),
+        predictor: Predictor = ZeroPredictor(),
+        direct_solve_operator: type[
+            chunktime.BidiagonalOperator
+        ] = chunktime.BidiagonalThomasFactorization,
+        nonlinear_solver: chunktime.NonlinearSolver = chunktime.ChunkNewtonRaphson(),
         callbacks=None,
         convert_nan_gradients=True,
     ):
@@ -457,6 +468,10 @@ def rewind(self, output_grad):
         )
 
         # Loop backwards through time
+        if self.result is None:
+            raise ValueError("No cached result found for adjoint rewind")
+
+        adjoint = None
         for k1, k2 in self.step_generator(len(self.result)).reverse():
             # Get our block of the results
             with torch.enable_grad():
@@ -479,6 +494,11 @@ def rewind(self, output_grad):
                     grad_result, adjoint, R[:1], retain=True
                 )
+            if adjoint is None:
+                raise NotImplementedError(
+                    "Currently only supports adjoint rewind for a single block"
+                )
+
             # Do the block adjoint update
             adjoint = self.block_update_adjoint(
                 J, output_grad[k1:k2].flip(0), adjoint[-1]
             )
@@ -488,6 +508,8 @@ def rewind(self, output_grad):
         with torch.enable_grad():
             grad_result = self.accumulate(grad_result, adjoint, R[1:])
+        if adjoint is None:
+            raise ValueError("No adjoint values calculated during rewind")
         if self.convert_nan_gradients:
             return tuple(torch.nan_to_num(g) for g in grad_result), torch.nan_to_num(
                 adjoint[-1]
             )
@@ -570,7 +592,7 @@ def forward(ctx, solver, y0, n, forces, *params):
         return y
 
     @staticmethod
-    def backward(ctx, output_grad):
+    def backward(ctx, *output_grad):
         with torch.no_grad():
             grad_res, adj_last = ctx.solver.rewind(output_grad)
             if ctx.needs_input_grad[1]:
diff --git a/pyzag/stochastic.py b/pyzag/stochastic.py
index 7452b11..6c697d6 100644
--- a/pyzag/stochastic.py
+++ b/pyzag/stochastic.py
@@ -27,8 +27,13 @@
 import torch
 
 import pyro
+import pyro.nn.module
 from pyro.nn import PyroSample
 import pyro.distributions as dist
+import pyro.poutine.scale_messenger
+import pyro.poutine.mask_messenger
+
+from typing import cast
 
 
 class MapNormal:
@@ -179,13 +184,14 @@ def forward(self, *args, results=None, weights=None, **kwargs):
             )
 
         if weights is None:
-            weights = torch.ones(shape[-1], device=self.eps.device)
+            weights = torch.ones(shape[-1], device=cast(torch.Tensor, self.eps).device)
 
         # Rather annoying that this is necessary, this is not a no-op as it tells pyro that these
         # are *not* batched over the number of samples
         _ = self._sample_top()
 
         # Same here
+        eps = None
         if self.sample_noise_outside:
             eps = self.eps
 
@@ -195,7 +201,7 @@ def forward(self, *args, results=None, weights=None, **kwargs):
         ), pyro.poutine.scale_messenger.ScaleMessenger(
             scale=weights
         ), pyro.poutine.mask_messenger.MaskMessenger(
-            mask=self.mask
+            mask=torch.BoolTensor(self.mask)
         ):
             self._sample_bot()
             res = self.base(*args, **kwargs)
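
Reviewer note: a minimal sketch (not part of the patch) of how the new `Predictor`
ABC in pyzag/nonlinear.py is meant to be used downstream, and what the
abstractmethod machinery buys us. `ConstantPredictor` and `BrokenPredictor` are
hypothetical illustration-only classes; only `nonlinear.Predictor` and its
`predict(results, k, kinc)` signature come from this diff.

    import torch

    from pyzag import nonlinear


    class ConstantPredictor(nonlinear.Predictor):
        """Fill the next chunk of steps with a fixed value (illustration only)."""

        def __init__(self, value: float):
            self.value = value

        def predict(self, results, k, kinc) -> torch.Tensor:
            # Same shape the built-in predictors return: results[k : k + kinc]
            return torch.full_like(results[k : k + kinc], self.value)


    # The ABC turns a missing predict() into a TypeError at construction time,
    # rather than an AttributeError deep inside the solve loop:
    class BrokenPredictor(nonlinear.Predictor):
        pass


    ConstantPredictor(0.0)  # fine
    try:
        BrokenPredictor()
    except TypeError as err:
        print(err)  # Can't instantiate abstract class BrokenPredictor ...

The same pattern applies to `chunktime.NonlinearSolver` for custom nonlinear
solvers. Note also that `direct_solve_operator` is annotated
`type[chunktime.BidiagonalOperator]`: the solver takes the operator *class*
(instantiated internally per chunk), not a pre-built instance, which the old
untyped default made easy to get wrong.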