From e47086148ec1fa246cdd8befeb5aa529941dd81e Mon Sep 17 00:00:00 2001
From: Mihir Patel
Date: Tue, 15 Nov 2022 14:52:45 -0800
Subject: [PATCH] Add inf eval check (#1733)

* add inf eval check

* add test
---
 composer/trainer/trainer.py        | 11 +++++++++++
 tests/trainer/test_trainer_eval.py | 34 +++++++++++++++++++++++++++++-
 2 files changed, 44 insertions(+), 1 deletion(-)

diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 98094f73de..aa96ff6f38 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import collections.abc
 import contextlib
 import datetime
 import itertools
@@ -147,6 +148,16 @@ def _set_evaluator_interval_and_subset_num_batches(
             evaluator.subset_num_batches = subset_num_batches
         if evaluator.eval_interval is None:
             evaluator.eval_interval = eval_interval
+        eval_dataloader = evaluator.dataloader.dataloader
+        if isinstance(eval_dataloader, collections.abc.Sized) and evaluator.subset_num_batches is None:
+            try:
+                dataloader_len = len(eval_dataloader)
+            except TypeError:
+                dataloader_len = None
+            if dataloader_len is None:
+                raise ValueError('eval_subset_num_batches must be set when using an infinite '
+                                 'eval_dataloader whose length is `None`. Otherwise, evaluation will '
+                                 'run forever and never terminate.')
 
 
 def _is_auto_grad_accum(grad_accum: Union[int, str], device: Device):
diff --git a/tests/trainer/test_trainer_eval.py b/tests/trainer/test_trainer_eval.py
index cf87431ebb..ce9c1bb3fd 100644
--- a/tests/trainer/test_trainer_eval.py
+++ b/tests/trainer/test_trainer_eval.py
@@ -1,7 +1,8 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Callable, Union
+import contextlib
+from typing import Callable, Optional, Union
 
 import pytest
 from torch.utils.data import DataLoader
@@ -301,3 +302,34 @@ def test_eval_params_evaluator():
     assert event_counter_callback.event_to_num_calls[Event.EVAL_START] == trainer.state.timestamp.batch
     assert event_counter_callback.event_to_num_calls[
         Event.EVAL_BATCH_START] == eval_subset_num_batches * trainer.state.timestamp.batch
+
+
+class InfiniteDataloader(DataLoader):
+    """Infinite dataloader that never raises StopIteration."""
+
+    def __iter__(self):
+        while True:
+            for batch in super().__iter__():
+                yield batch
+
+    def __len__(self) -> Optional[int]:
+        return None
+
+
+@pytest.mark.parametrize('eval_subset_num_batches', [None, 1])
+def test_infinite_eval_dataloader(eval_subset_num_batches):
+    """Test that `eval_subset_num_batches` is required with an infinite dataloader."""
+    # Construct the dataloaders and the trainer
+    train_dataset = RandomClassificationDataset()
+    train_dataloader = DataLoader(train_dataset, sampler=dist.get_sampler(train_dataset))
+    eval_dataset = RandomClassificationDataset()
+    eval_dataloader = InfiniteDataloader(eval_dataset, sampler=dist.get_sampler(eval_dataset))
+
+    with contextlib.nullcontext() if eval_subset_num_batches is not None else pytest.raises(ValueError):
+        Trainer(
+            model=SimpleModel(),
+            train_dataloader=train_dataloader,
+            eval_dataloader=eval_dataloader,
+            max_duration='1ep',
+            eval_subset_num_batches=eval_subset_num_batches,
+        )
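
Note on the check in trainer.py: `len()` raises `TypeError` whenever an object's `__len__` returns something other than a non-negative int, so a dataloader whose `__len__` returns `None` (like the test's `InfiniteDataloader`) fails the probe, and the patch turns that into an explicit `ValueError` at trainer construction time rather than letting evaluation loop forever. A minimal, standalone sketch of that length-probe pattern follows; `InfiniteLoader` and `probe_len` are hypothetical illustrations for this note, not part of the Composer API:

    import collections.abc
    from typing import Optional


    class InfiniteLoader:
        """Stand-in for a dataloader whose length is unknown (treated as infinite)."""

        def __len__(self) -> Optional[int]:
            # Returning None makes len(self) raise TypeError, since len()
            # requires __len__ to return a non-negative int.
            return None


    def probe_len(loader) -> Optional[int]:
        """Return the loader's length, or None if it cannot be determined.

        Mirrors the patch: only Sized loaders are probed, and a TypeError
        from len() marks the length as unknown.
        """
        if not isinstance(loader, collections.abc.Sized):
            return None  # no __len__ at all; the patch skips these entirely
        try:
            return len(loader)
        except TypeError:
            return None  # __len__ exists but did not return an int


    assert probe_len([1, 2, 3]) == 3
    assert probe_len(InfiniteLoader()) is None  # the case that triggers the ValueError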
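
For callers hitting the new error, the fix is to cap evaluation explicitly. A hypothetical usage sketch, reusing `SimpleModel`, the dataloaders, and `InfiniteDataloader` from the test above (the cap of 10 batches is an arbitrary example value):

    # With an explicit batch cap, construction succeeds and each evaluation
    # runs exactly eval_subset_num_batches batches before returning.
    trainer = Trainer(
        model=SimpleModel(),
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,  # infinite: __len__ returns None
        max_duration='1ep',
        eval_subset_num_batches=10,  # arbitrary cap; omitting it raises ValueError
    )
    trainer.fit()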