Skip to content

Commit

Permalink
fix: TorchWriter.reset() closes SummaryWriter (#5375)
Browse files Browse the repository at this point in the history
The next write will reopen the writers automatically, so the net
effect of this change is that new files are created for the first
write after each reset.
  • Loading branch information
mpkouznetsov authored Nov 3, 2022
1 parent f9d8cfe commit 282ddc0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
4 changes: 2 additions & 2 deletions harness/determined/tensorboard/metric_writers/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ def add_scalar(self, name: str, value: Union[int, float, np.number], step: int)
self.writer.add_scalar(name, value, step)

def reset(self) -> None:
    """Close the underlying SummaryWriter so this writer starts fresh.

    Closing (rather than merely flushing) the writer means the next call
    that writes a metric will lazily reopen it, which creates a brand-new
    event file on disk for the first write after each reset.
    """
    # flush AND close the writer so that the next attempt to write will create a new file
    # (SummaryWriter.close() flushes pending events before closing).
    self.writer.close()
22 changes: 22 additions & 0 deletions harness/tests/tensorboard/test_torch_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pathlib
from typing import Any, Dict

from _pytest import monkeypatch

from determined import tensorboard
from determined.tensorboard.metric_writers import pytorch


def test_torch_writer(monkeypatch: monkeypatch.MonkeyPatch, tmp_path: pathlib.Path) -> None:
    """reset() must close the underlying SummaryWriter, so each
    write-then-reset cycle produces its own event file on disk."""

    def fake_base_path(_config: Dict[str, Any]) -> pathlib.Path:
        # Redirect tensorboard output into the pytest-provided temp dir.
        return tmp_path

    monkeypatch.setattr(tensorboard, "get_base_path", fake_base_path)

    writer = pytorch.TorchWriter()
    for step, value in enumerate((7, 8)):
        writer.add_scalar("foo", value, step)
        writer.reset()

    # Two write/reset cycles -> two distinct event files.
    assert len(list(tmp_path.iterdir())) == 2

0 comments on commit 282ddc0

Please sign in to comment.