Add inf eval check (#1733)
* add inf eval check

* add test
mvpatel2000 authored and Bandish Shah committed Nov 15, 2022
1 parent 1722c61 commit e470861
Showing 2 changed files with 44 additions and 1 deletion.
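
In effect, the change makes the Trainer fail fast: if an eval dataloader reports no usable length and eval_subset_num_batches is unset, construction raises a ValueError instead of letting evaluation loop forever. Below is a minimal caller-side sketch of the new contract, not part of the diff; SimpleModel and RandomClassificationDataset are helpers from the Composer test suite, and the tests.common import path is an assumption:

from torch.utils.data import DataLoader

from composer import Trainer
from composer.utils import dist
from tests.common import RandomClassificationDataset, SimpleModel

train_dataset = RandomClassificationDataset()
eval_dataset = RandomClassificationDataset()

trainer = Trainer(
    model=SimpleModel(),
    train_dataloader=DataLoader(train_dataset, sampler=dist.get_sampler(train_dataset)),
    # With an infinite eval dataloader (len() unusable), omitting
    # eval_subset_num_batches would now raise ValueError right here.
    eval_dataloader=DataLoader(eval_dataset, sampler=dist.get_sampler(eval_dataset)),
    max_duration='1ep',
    eval_subset_num_batches=1,  # cap evaluation so it always terminates
)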
11 changes: 11 additions & 0 deletions composer/trainer/trainer.py
@@ -5,6 +5,7 @@
 
 from __future__ import annotations
 
+import collections.abc
 import contextlib
 import datetime
 import itertools
@@ -147,6 +148,16 @@ def _set_evaluator_interval_and_subset_num_batches(
         evaluator.subset_num_batches = subset_num_batches
         if evaluator.eval_interval is None:
             evaluator.eval_interval = eval_interval
+        eval_dataloader = evaluator.dataloader.dataloader
+        if isinstance(eval_dataloader, collections.abc.Sized) and evaluator.subset_num_batches is None:
+            try:
+                dataloader_len = len(eval_dataloader)
+            except TypeError:
+                dataloader_len = None
+            if dataloader_len is None:
+                raise ValueError('eval_subset_num_batches must be set when using an infinite sized '
+                                 'eval_dataloader where length is `None`. Otherwise, evaluation will '
+                                 'run forever and never terminate.')
 
 
 def _is_auto_grad_accum(grad_accum: Union[int, str], device: Device):
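
Why the check needs both the isinstance test and the try/except: an object satisfies isinstance(obj, collections.abc.Sized) merely by defining __len__, yet len(obj) still raises TypeError when that __len__ returns a non-integer such as None, which is exactly how the InfiniteDataloader in the test below advertises an unknown length. A standalone sketch of that Python behavior (the class name is hypothetical):

import collections.abc


class NoUsableLength:
    """Hypothetical stand-in for a dataloader of unknown length."""

    def __len__(self):
        return None  # signals "length unknown / infinite"


obj = NoUsableLength()
assert isinstance(obj, collections.abc.Sized)  # True: __len__ is defined

try:
    len(obj)
except TypeError:
    # len() requires __len__ to return an int, so None raises TypeError,
    # which the trainer code above maps to dataloader_len = None.
    print('length is unusable; the trainer treats it as None and raises')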
34 changes: 33 additions & 1 deletion tests/trainer/test_trainer_eval.py
@@ -1,7 +1,8 @@
 # Copyright 2022 MosaicML Composer authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Callable, Union
+import contextlib
+from typing import Callable, Optional, Union
 
 import pytest
 from torch.utils.data import DataLoader
@@ -301,3 +302,34 @@ def test_eval_params_evaluator():
     assert event_counter_callback.event_to_num_calls[Event.EVAL_START] == trainer.state.timestamp.batch
     assert event_counter_callback.event_to_num_calls[
         Event.EVAL_BATCH_START] == eval_subset_num_batches * trainer.state.timestamp.batch
+
+
+class InfiniteDataloader(DataLoader):
+    """Infinite dataloader that never raises StopIteration."""
+
+    def __iter__(self):
+        while True:
+            for batch in super().__iter__():
+                yield batch
+
+    def __len__(self) -> Optional[int]:
+        return None
+
+
+@pytest.mark.parametrize('eval_subset_num_batches', [None, 1])
+def test_infinite_eval_dataloader(eval_subset_num_batches):
+    """Test that `eval_subset_num_batches` is required with an infinite dataloader."""
+    # Construct the trainer
+    train_dataset = RandomClassificationDataset()
+    train_dataloader = DataLoader(train_dataset, sampler=dist.get_sampler(train_dataset))
+    eval_dataset = RandomClassificationDataset()
+    eval_dataloader = InfiniteDataloader(eval_dataset, sampler=dist.get_sampler(eval_dataset))
+
+    with contextlib.nullcontext() if eval_subset_num_batches is not None else pytest.raises(ValueError):
+        Trainer(
+            model=SimpleModel(),
+            train_dataloader=train_dataloader,
+            eval_dataloader=eval_dataloader,
+            max_duration='1ep',
+            eval_subset_num_batches=eval_subset_num_batches,
+        )
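
The with-statement above shares one test body between the passing and failing parametrizations: contextlib.nullcontext() expects no exception when eval_subset_num_batches is set, while pytest.raises(ValueError) requires the new check to fire when it is None. The same pattern in isolation, with might_raise as a hypothetical function for illustration:

import contextlib

import pytest


def might_raise(flag: bool) -> None:
    """Hypothetical helper: raises only when the flag is set."""
    if flag:
        raise ValueError('flag was set')


@pytest.mark.parametrize('flag', [True, False])
def test_might_raise(flag: bool) -> None:
    # Expect ValueError when flag is True; expect clean success otherwise.
    ctx = pytest.raises(ValueError) if flag else contextlib.nullcontext()
    with ctx:
        might_raise(flag)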
