diff --git a/CHANGELOG.md b/CHANGELOG.md index 261c4e4..858d04c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. ## [Unreleased] +### Added + - Added `get_state` and `set_state` methods + - Added `accumulate_state` method ## [0.2.0] - 2021-11-22 diff --git a/tests/test_stats.py b/tests/test_stats.py index 17e1456..6dddc9c 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -163,6 +163,88 @@ def test_batching(do_accumulate_by, nan_attrs, allclose): runstats.reset(reset_n_bins=True) +@pytest.mark.parametrize( + "dim,reduce_dims", + [ + (1, tuple()), + (1, (0,)), + (3, tuple()), + (3, (0,)), + ((2, 3), tuple()), + (torch.Size((1, 2, 1)), tuple()), + (torch.Size((1, 2, 1)), (1,)), + (torch.Size((3, 2, 4)), (0, 2)), + (torch.Size((3, 2, 4)), (0, 1, 2)), + ], +) +@pytest.mark.parametrize("do_accumulate_by", [True, False]) +@pytest.mark.parametrize("reduction", [Reduction.MEAN, Reduction.RMS]) +def test_state(dim, reduce_dims, do_accumulate_by, reduction, allclose): + runstats1, runstats2 = [ + RunningStats(dim=dim, reduction=reduction, reduce_dims=reduce_dims) + for _ in range(2) + ] + batch1, batch2 = [ + torch.randn((random.randint(1, 10),) + runstats1.dim) for _ in range(2) + ] + if do_accumulate_by: + acc_by1, acc_by2 = [ + torch.randint(0, random.randint(1, 5), size=(batch.shape[0],)) + for batch in (batch1, batch2) + ] + else: + acc_by1, acc_by2 = None, None + runstats1.accumulate_batch(batch1, accumulate_by=acc_by1) + runstats2.accumulate_batch(batch2, accumulate_by=acc_by2) + _, res2 = runstats1.current_result(), runstats2.current_result() + # now, load the state of 2 -> 1 + runstats1.set_state(runstats2.get_state()) + # should be the same since moved the state + assert allclose(runstats1.current_result(), res2) + + +@pytest.mark.parametrize( + "dim,reduce_dims", + [ + (1, tuple()), + (1, (0,)), + (3, tuple()), + (3, (0,)), + 
((2, 3), tuple()), + (torch.Size((1, 2, 1)), tuple()), + (torch.Size((1, 2, 1)), (1,)), + (torch.Size((3, 2, 4)), (0, 2)), + (torch.Size((3, 2, 4)), (0, 1, 2)), + ], +) +@pytest.mark.parametrize("do_accumulate_by", [True, False]) +@pytest.mark.parametrize("reduction", [Reduction.MEAN, Reduction.RMS]) +def test_accumulate_state(dim, reduce_dims, do_accumulate_by, reduction, allclose): + runstats1, runstats2, runstats3 = [ + RunningStats(dim=dim, reduction=reduction, reduce_dims=reduce_dims) + for _ in range(3) + ] + batch1, batch2 = [ + torch.randn((random.randint(1, 10),) + runstats1.dim) for _ in range(2) + ] + if do_accumulate_by: + acc_by1, acc_by2 = [ + torch.randint(0, random.randint(1, 5), size=(batch.shape[0],)) + for batch in (batch1, batch2) + ] + else: + acc_by1, acc_by2 = None, None + runstats1.accumulate_batch(batch1, accumulate_by=acc_by1) + runstats2.accumulate_batch(batch2, accumulate_by=acc_by2) + # now accumulate batch2 into runstats1 through the state + runstats1.accumulate_state(runstats2.get_state()) + # and make a truth baseline + runstats3.accumulate_batch(batch1, accumulate_by=acc_by1) + runstats3.accumulate_batch(batch2, accumulate_by=acc_by2) + # and check: + assert allclose(runstats1.current_result(), runstats3.current_result()) + + @pytest.mark.parametrize("reduction", [Reduction.MEAN, Reduction.RMS]) def test_zeros(reduction, allclose): dim = (4,) diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 0000000..705c31e --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,18 @@ +import pytest +import random + +import torch + +from torch_runstats._util import _pad_dim0 + + +@pytest.mark.parametrize("ndim", [1, 2, 4]) +def test_pad_dim0(ndim): + orig_shape = tuple(random.randint(1, 5) for _ in range(ndim)) + x = torch.ones(orig_shape) + to_add = 3 + padded = _pad_dim0(x, to_add) + assert padded.shape[1:] == orig_shape[1:] + assert padded.shape[0] == orig_shape[0] + to_add + assert torch.equal(x, padded[:-to_add]) + 
assert padded[-to_add:].abs().max() == 0 diff --git a/torch_runstats/_runstats.py b/torch_runstats/_runstats.py index de7886e..343cfc3 100644 --- a/torch_runstats/_runstats.py +++ b/torch_runstats/_runstats.py @@ -6,6 +6,7 @@ import torch from .scatter import scatter +from ._util import _pad_dim0 def _prod(x): @@ -236,28 +237,10 @@ def accumulate_batch( # do we need new bins? N_to_add = new_sum.shape[0] - self._n_bins if N_to_add > 0: - - # time to expand - self._state = torch.cat( - ( - self._state, - self._state.new_zeros((N_to_add,) + self._state.shape[1:]), - ), - dim=0, - ) - self._n = torch.cat( - (self._n, self._n.new_zeros((N_to_add,) + self._n.shape[1:])), dim=0 - ) - - # assert self._state.shape == (self._n_bins + N_to_add,) + self._dim - self._n_bins += N_to_add - + self._expand_state(N_to_add) elif N_to_add < 0: - - new_sum = torch.cat( - (new_sum, new_sum.new_zeros((-N_to_add,) + new_sum.shape[1:])), dim=0 - ) - N = torch.cat((N, N.new_zeros((-N_to_add,) + N.shape[1:])), dim=0) + new_sum = _pad_dim0(new_sum, -N_to_add) + N = _pad_dim0(N, -N_to_add) self._state += (new_sum - N * self._state) / (self._n + N) self._n += N @@ -305,6 +288,53 @@ def current_result(self) -> torch.Tensor: elif self._reduction == Reduction.RMS: return self._state.sqrt() + def get_state(self) -> Tuple[torch.Tensor, ...]: + """Get the current internal state of the object for later use. + + The contents of this tuple of tensors have no guaranteed format and should + only be used within a program and with ``RunningStats`` objects that were + constructed with exactly identical parameters. The format of the result + is NOT guaranteed to be consistent across versions and should not be + serialized. + + The returned tensors are copies of the internal state and are safe to + mutate. + + Returns: + a tuple of tensors. 
+ """ + return tuple(t.clone() for t in (self._n, self._state)) + + def set_state(self, state: Tuple[torch.Tensor, ...]) -> None: + """Set the internal state of this object to ``state`` from an earlier call to ``get_state``. + + Args: + state: an internal state of a ``RunningStats`` object of identical parameters retreived + with calling its ``get_state`` method. + """ + n, state = state + self._n_bins = len(n) + self._n = n.to(dtype=self._n.dtype, device=self._n.device) + self._state = state.to(dtype=self._state.dtype, device=self._state.device) + + def accumulate_state(self, state: Tuple[torch.Tensor, ...]) -> None: + """ """ + device = self._state.device + n, state = [e.to(device) for e in state] + + N_to_add = len(n) - self.n_bins + if N_to_add > 0: + self._expand_state(N_to_add) + elif N_to_add < 0: + # need to expand the parameter state + n = _pad_dim0(n, -N_to_add) + state = _pad_dim0(state, -N_to_add) + + self._state += n * (state - self._state) / (self._n + n) + self._n += n + # Make div by zero 0 + self._state = torch.nan_to_num_(self._state, nan=0.0) + @property def n(self) -> torch.Tensor: """The number of samples processed so far in each bin. 
@@ -338,3 +368,21 @@ def reduce_dims(self) -> Tuple[int, ...]: def reduction(self) -> Reduction: """The reduction computed by this object.""" return self._reduction + + def _expand_state(self, N_to_add: int) -> None: + if N_to_add == 0: + return + elif N_to_add < 0: + raise ValueError + # time to expand + self._state = torch.cat( + ( + self._state, + self._state.new_zeros((N_to_add,) + self._state.shape[1:]), + ), + dim=0, + ) + self._n = torch.cat( + (self._n, self._n.new_zeros((N_to_add,) + self._n.shape[1:])), dim=0 + ) + self._n_bins += N_to_add diff --git a/torch_runstats/_util.py b/torch_runstats/_util.py new file mode 100644 index 0000000..702b53d --- /dev/null +++ b/torch_runstats/_util.py @@ -0,0 +1,9 @@ +import torch + + +def _pad_dim0(x: torch.Tensor, n: int) -> torch.Tensor: + if n == 0: + return x + elif n < 0: + raise ValueError + return torch.nn.functional.pad(x, (0,) * ((x.ndim - 1) * 2) + (0, n)) diff --git a/torch_runstats/_version.py b/torch_runstats/_version.py index d3ec452..3ced358 100644 --- a/torch_runstats/_version.py +++ b/torch_runstats/_version.py @@ -1 +1 @@ -__version__ = "0.2.0" +__version__ = "0.2.1"