From e3d626b796f44ae073667791ea3fb7504e79e789 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Tue, 21 Dec 2021 22:19:13 -0500 Subject: [PATCH 1/7] add padding util --- tests/test_util.py | 18 ++++++++++++++++++ torch_runstats/_util.py | 9 +++++++++ 2 files changed, 27 insertions(+) create mode 100644 tests/test_util.py create mode 100644 torch_runstats/_util.py diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 0000000..705c31e --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,18 @@ +import pytest +import random + +import torch + +from torch_runstats._util import _pad_dim0 + + +@pytest.mark.parametrize("ndim", [1, 2, 4]) +def test_pad_dim0(ndim): + orig_shape = tuple(random.randint(1, 5) for _ in range(ndim)) + x = torch.ones(orig_shape) + to_add = 3 + padded = _pad_dim0(x, to_add) + assert padded.shape[1:] == orig_shape[1:] + assert padded.shape[0] == orig_shape[0] + to_add + assert torch.equal(x, padded[:-to_add]) + assert padded[-to_add:].abs().max() == 0 diff --git a/torch_runstats/_util.py b/torch_runstats/_util.py new file mode 100644 index 0000000..702b53d --- /dev/null +++ b/torch_runstats/_util.py @@ -0,0 +1,9 @@ +import torch + + +def _pad_dim0(x: torch.Tensor, n: int) -> torch.Tensor: + if n == 0: + return + elif n < 0: + raise ValueError + return torch.nn.functional.pad(x, (0,) * ((x.ndim - 1) * 2) + (0, n)) From 3f30cafd471d97a644a4d579af044ae5f8f6a55f Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Tue, 21 Dec 2021 22:30:03 -0500 Subject: [PATCH 2/7] get/set state --- CHANGELOG.md | 2 ++ tests/test_stats.py | 41 +++++++++++++++++++++ torch_runstats/_runstats.py | 72 ++++++++++++++++++++++++++----------- 3 files changed, 94 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 261c4e4..cde76ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,8 @@ and this project adheres to [Semantic 
Versioning](https://semver.org/spec/v2.0.0 Most recent change on the bottom. ## [Unreleased] +### Added + - `get_state` and `set_state` methods ## [0.2.0] - 2021-11-22 diff --git a/tests/test_stats.py b/tests/test_stats.py index 17e1456..d50cc5f 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -163,6 +163,47 @@ def test_batching(do_accumulate_by, nan_attrs, allclose): runstats.reset(reset_n_bins=True) +@pytest.mark.parametrize( + "dim,reduce_dims", + [ + (1, tuple()), + (1, (0,)), + (3, tuple()), + (3, (0,)), + ((2, 3), tuple()), + (torch.Size((1, 2, 1)), tuple()), + (torch.Size((1, 2, 1)), (1,)), + (torch.Size((3, 2, 4)), (0, 2)), + (torch.Size((3, 2, 4)), (0, 1, 2)), + ], +) +@pytest.mark.parametrize("do_accumulate_by", [True, False]) +@pytest.mark.parametrize("reduction", [Reduction.MEAN, Reduction.RMS]) +def test_state(dim, reduce_dims, do_accumulate_by, reduction, allclose): + truth_obj = StatsTruth(dim=dim, reduction=reduction, reduce_dims=reduce_dims) + runstats1, runstats2 = [ + RunningStats(dim=dim, reduction=reduction, reduce_dims=reduce_dims) + for _ in range(2) + ] + batch1, batch2 = [ + torch.randn((random.randint(1, 10),) + runstats1.dim) for _ in range(2) + ] + if do_accumulate_by: + acc_by1, acc_by2 = [ + torch.randint(0, random.randint(1, 5), size=(batch.shape[0],)) + for batch in (batch1, batch2) + ] + else: + acc_by1, acc_by2 = None, None + runstats1.accumulate_batch(batch1, accumulate_by=acc_by1) + runstats2.accumulate_batch(batch2, accumulate_by=acc_by2) + res1, res2 = runstats1.current_result(), runstats2.current_result() + # now, load the state of 2 -> 1 + runstats1.set_state(runstats2.get_state()) + # should be the same since moved the state + assert torch.allclose(runstats1.current_result(), res2) + + @pytest.mark.parametrize("reduction", [Reduction.MEAN, Reduction.RMS]) def test_zeros(reduction, allclose): dim = (4,) diff --git a/torch_runstats/_runstats.py b/torch_runstats/_runstats.py index de7886e..ac63eac 100644 --- 
a/torch_runstats/_runstats.py +++ b/torch_runstats/_runstats.py @@ -6,6 +6,7 @@ import torch from .scatter import scatter +from ._util import _pad_dim0 def _prod(x): @@ -236,28 +237,10 @@ def accumulate_batch( # do we need new bins? N_to_add = new_sum.shape[0] - self._n_bins if N_to_add > 0: - - # time to expand - self._state = torch.cat( - ( - self._state, - self._state.new_zeros((N_to_add,) + self._state.shape[1:]), - ), - dim=0, - ) - self._n = torch.cat( - (self._n, self._n.new_zeros((N_to_add,) + self._n.shape[1:])), dim=0 - ) - - # assert self._state.shape == (self._n_bins + N_to_add,) + self._dim - self._n_bins += N_to_add - + self._expand_state(N_to_add) elif N_to_add < 0: - - new_sum = torch.cat( - (new_sum, new_sum.new_zeros((-N_to_add,) + new_sum.shape[1:])), dim=0 - ) - N = torch.cat((N, N.new_zeros((-N_to_add,) + N.shape[1:])), dim=0) + new_sum = _pad_dim0(new_sum, -N_to_add) + N = _pad_dim0(N, -N_to_add) self._state += (new_sum - N * self._state) / (self._n + N) self._n += N @@ -305,6 +288,35 @@ def current_result(self) -> torch.Tensor: elif self._reduction == Reduction.RMS: return self._state.sqrt() + def get_state(self) -> Tuple[torch.Tensor, ...]: + """Get the current internal state of the object for later use. + + The contents of this tuple of tensors have no guaranteed format and should + only be used within a program and with ``RunningStats`` objects that were + constructed with exactly identical parameters. The format of the result + is NOT guaranteed to be consistent across versions and should not be + serialized. + + The returned tensors are copies of the internal state and are safe to + mutate. + + Returns: + a tuple of tensors. + """ + return tuple(t.clone() for t in (self._n, self._state)) + + def set_state(self, state: Tuple[torch.Tensor, ...]) -> None: + """Set the internal state of this object to ``state`` from an earlier call to ``get_state``.
+ + Args: + state: an internal state of a ``RunningStats`` object of identical parameters retrieved + by calling its ``get_state`` method. + """ + n, state = state + self._n_bins = len(n) + self._n = n.to(dtype=self._n.dtype, device=self._n.device) + self._state = state.to(dtype=self._state.dtype, device=self._state.device) + @property def n(self) -> torch.Tensor: """The number of samples processed so far in each bin. @@ -338,3 +350,21 @@ def reduce_dims(self) -> Tuple[int, ...]: def reduction(self) -> Reduction: """The reduction computed by this object.""" return self._reduction + + def _expand_state(self, N_to_add: int) -> None: + if N_to_add == 0: + return + elif N_to_add < 0: + raise ValueError + # time to expand + self._state = torch.cat( + ( + self._state, + self._state.new_zeros((N_to_add,) + self._state.shape[1:]), + ), + dim=0, + ) + self._n = torch.cat( + (self._n, self._n.new_zeros((N_to_add,) + self._n.shape[1:])), dim=0 + ) + self._n_bins += N_to_add From 0e23af56c9e8b250e195ab7797cdf954ac35636c Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Tue, 21 Dec 2021 22:34:16 -0500 Subject: [PATCH 3/7] add accumulate_state --- CHANGELOG.md | 3 ++- tests/test_stats.py | 45 +++++++++++++++++++++++++++++++++++-- torch_runstats/_runstats.py | 15 +++++++++++++ 3 files changed, 60 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cde76ff..858d04c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,8 @@ Most recent change on the bottom.
## [Unreleased] ### Added - - `get_state` and `set_state` methods + - Added `get_state` and `set_state` methods + - Added `accumulate_state` method ## [0.2.0] - 2021-11-22 diff --git a/tests/test_stats.py b/tests/test_stats.py index d50cc5f..a74a7e8 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -180,7 +180,6 @@ def test_batching(do_accumulate_by, nan_attrs, allclose): @pytest.mark.parametrize("do_accumulate_by", [True, False]) @pytest.mark.parametrize("reduction", [Reduction.MEAN, Reduction.RMS]) def test_state(dim, reduce_dims, do_accumulate_by, reduction, allclose): - truth_obj = StatsTruth(dim=dim, reduction=reduction, reduce_dims=reduce_dims) runstats1, runstats2 = [ RunningStats(dim=dim, reduction=reduction, reduce_dims=reduce_dims) for _ in range(2) @@ -201,7 +200,49 @@ def test_state(dim, reduce_dims, do_accumulate_by, reduction, allclose): # now, load the state of 2 -> 1 runstats1.set_state(runstats2.get_state()) # should be the same since moved the state - assert torch.allclose(runstats1.current_result(), res2) + assert allclose(runstats1.current_result(), res2) + + +@pytest.mark.parametrize( + "dim,reduce_dims", + [ + (1, tuple()), + (1, (0,)), + (3, tuple()), + (3, (0,)), + ((2, 3), tuple()), + (torch.Size((1, 2, 1)), tuple()), + (torch.Size((1, 2, 1)), (1,)), + (torch.Size((3, 2, 4)), (0, 2)), + (torch.Size((3, 2, 4)), (0, 1, 2)), + ], +) +@pytest.mark.parametrize("do_accumulate_by", [True, False]) +@pytest.mark.parametrize("reduction", [Reduction.MEAN, Reduction.RMS]) +def test_accumulate_state(dim, reduce_dims, do_accumulate_by, reduction, allclose): + runstats1, runstats2, runstats3 = [ + RunningStats(dim=dim, reduction=reduction, reduce_dims=reduce_dims) + for _ in range(3) + ] + batch1, batch2 = [ + torch.randn((random.randint(1, 10),) + runstats1.dim) for _ in range(2) + ] + if do_accumulate_by: + acc_by1, acc_by2 = [ + torch.randint(0, random.randint(1, 5), size=(batch.shape[0],)) + for batch in (batch1, batch2) + ] + else: + 
acc_by1, acc_by2 = None, None + runstats1.accumulate_batch(batch1, accumulate_by=acc_by1) + runstats2.accumulate_batch(batch2, accumulate_by=acc_by2) + # now accumulate batch2 into runstats1 through the state + runstats1.accumulate_state(runstats2.get_state()) + # and make a truth baseline + runstats3.accumulate_batch(batch1, accumulate_by=acc_by1) + runstats3.accumulate_batch(batch2, accumulate_by=acc_by2) + # and check: + assert allclose(runstats1.current_result(), runstats3.current_result()) @pytest.mark.parametrize("reduction", [Reduction.MEAN, Reduction.RMS]) diff --git a/torch_runstats/_runstats.py b/torch_runstats/_runstats.py index ac63eac..a3eac54 100644 --- a/torch_runstats/_runstats.py +++ b/torch_runstats/_runstats.py @@ -317,6 +317,21 @@ def set_state(self, state: Tuple[torch.Tensor, ...]) -> None: self._n = n.to(dtype=self._n.dtype, device=self._n.device) self._state = state.to(dtype=self._state.dtype, device=self._state.device) + def accumulate_state(self, state: Tuple[torch.Tensor, ...]) -> None: + """ """ + n, state = state + N_to_add = len(n) - self.n_bins + if N_to_add > 0: + self._expand_state(N_to_add) + elif N_to_add < 0: + # need to expand the parameter state + n = _pad_dim0(n, -N_to_add) + state = _pad_dim0(state, -N_to_add) + self._state += n * (state - self._state) / (self._n + n) + self._n += n + # Make div by zero 0 + self._state = torch.nan_to_num_(self._state, nan=0.0) + @property def n(self) -> torch.Tensor: """The number of samples processed so far in each bin. 
From 518cf16be0df429d972936771b44d7e9af4939cc Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 21 Jan 2022 14:50:09 -0500 Subject: [PATCH 4/7] lint --- tests/test_stats.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_stats.py b/tests/test_stats.py index a74a7e8..6dddc9c 100644 --- a/tests/test_stats.py +++ b/tests/test_stats.py @@ -196,7 +196,7 @@ def test_state(dim, reduce_dims, do_accumulate_by, reduction, allclose): acc_by1, acc_by2 = None, None runstats1.accumulate_batch(batch1, accumulate_by=acc_by1) runstats2.accumulate_batch(batch2, accumulate_by=acc_by2) - res1, res2 = runstats1.current_result(), runstats2.current_result() + _, res2 = runstats1.current_result(), runstats2.current_result() # now, load the state of 2 -> 1 runstats1.set_state(runstats2.get_state()) # should be the same since moved the state From 9501c360b9f5fbb1ad3a30fcbd8410b45f3f5fd4 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Fri, 21 Jan 2022 14:50:23 -0500 Subject: [PATCH 5/7] bump --- torch_runstats/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_runstats/_version.py b/torch_runstats/_version.py index d3ec452..3ced358 100644 --- a/torch_runstats/_version.py +++ b/torch_runstats/_version.py @@ -1 +1 @@ -__version__ = "0.2.0" +__version__ = "0.2.1" From 34751e829c93e28fc102773c7d87646f3c540339 Mon Sep 17 00:00:00 2001 From: JonathanSchmidt1 Date: Mon, 20 Mar 2023 17:47:34 +0100 Subject: [PATCH 6/7] set device in _runstats accumulate state --- torch_runstats/_runstats.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torch_runstats/_runstats.py b/torch_runstats/_runstats.py index a3eac54..1f48fc2 100644 --- a/torch_runstats/_runstats.py +++ b/torch_runstats/_runstats.py @@ -327,8 +327,10 @@ def accumulate_state(self, state: Tuple[torch.Tensor, ...]) -> None: # need to expand the 
parameter state n = _pad_dim0(n, -N_to_add) state = _pad_dim0(state, -N_to_add) - self._state += n * (state - self._state) / (self._n + n) - self._n += n + + device = self._state.device + self._state += n.to(device) * (state.to(device) - self._state) / (self._n + n.to(device)) + self._n += n.to(device) # Make div by zero 0 self._state = torch.nan_to_num_(self._state, nan=0.0) From ac333efdbf2c1cbcbbb752bbb49171fa70510324 Mon Sep 17 00:00:00 2001 From: Linux-cpp-lisp <1473644+Linux-cpp-lisp@users.noreply.github.com> Date: Mon, 20 Mar 2023 17:56:43 -0400 Subject: [PATCH 7/7] change device before padding in case --- torch_runstats/_runstats.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torch_runstats/_runstats.py b/torch_runstats/_runstats.py index 1f48fc2..343cfc3 100644 --- a/torch_runstats/_runstats.py +++ b/torch_runstats/_runstats.py @@ -319,7 +319,9 @@ def set_state(self, state: Tuple[torch.Tensor, ...]) -> None: def accumulate_state(self, state: Tuple[torch.Tensor, ...]) -> None: """ """ - n, state = state + device = self._state.device + n, state = [e.to(device) for e in state] + N_to_add = len(n) - self.n_bins if N_to_add > 0: self._expand_state(N_to_add) @@ -328,9 +330,8 @@ def accumulate_state(self, state: Tuple[torch.Tensor, ...]) -> None: n = _pad_dim0(n, -N_to_add) state = _pad_dim0(state, -N_to_add) - device = self._state.device - self._state += n.to(device) * (state.to(device) - self._state) / (self._n + n.to(device)) - self._n += n.to(device) + self._state += n * (state - self._state) / (self._n + n) + self._n += n # Make div by zero 0 self._state = torch.nan_to_num_(self._state, nan=0.0)