From 9e2763e097e08a42caa9b7e03a88b6fff38621a5 Mon Sep 17 00:00:00 2001 From: vfdev Date: Tue, 3 Dec 2024 11:38:46 +0100 Subject: [PATCH 1/4] Update requirements-dev.txt (#3310) --- requirements-dev.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index d475e556cdf..91b560e5653 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -21,7 +21,7 @@ mlflow neptune-client>=0.16.17 tensorboard torchvision -pynvml +pynvml<12 # pynvml module was removed in 12.X, is not developed or maintained. We should replace pynvml with something else. clearml scikit-image py-rouge From 1c3b9e975073bd4be47533fe98adf537b2ea67b4 Mon Sep 17 00:00:00 2001 From: vfdev Date: Tue, 3 Dec 2024 13:07:30 +0100 Subject: [PATCH 2/4] Fixed GPU tests exec scripts and failing metrics (#3301) * Fixed GPU tests and failing metrics * Updated timeout param * Updated infra cuda12.1 -> cuda12.4 * Add tmate for debug * Disable sudo * Attempt to debug tmate! * Attempt to use bash in step * Update gpu-tests.yml * Skip failing test and remove tmate debugging * Fixed formatting --------- Co-authored-by: Sadra Barikbin --- .github/workflows/gpu-tests.yml | 2 +- .../clustering/calinski_harabasz_score.py | 4 ++-- .../metrics/clustering/davies_bouldin_score.py | 4 ++-- ignite/metrics/clustering/silhouette_score.py | 4 ++-- .../metrics/regression/kendall_correlation.py | 4 ++-- .../metrics/regression/spearman_correlation.py | 4 ++-- tests/common_test_functionality.sh | 5 ++--- .../metrics/test_classification_report.py | 17 +++++++++++++++++ tests/ignite/metrics/test_hsic.py | 4 ++-- tests/run_cpu_tests.sh | 5 ++--- tests/run_gpu_tests.sh | 10 +++++----- 11 files changed, 39 insertions(+), 24 deletions(-) diff --git a/.github/workflows/gpu-tests.yml b/.github/workflows/gpu-tests.yml index 0a72711fbdd..13c628ad302 100644 --- a/.github/workflows/gpu-tests.yml +++ b/.github/workflows/gpu-tests.yml @@ -124,7 +124,7 @@ jobs: uses: nick-fields/retry@v2.9.0 with: max_attempts: 5 - timeout_minutes: 25 + timeout_minutes: 45 shell: bash command: docker exec -t pthd /bin/bash -xec 'bash tests/run_gpu_tests.sh 2' new_command_on_retry: docker exec -e USE_LAST_FAILED=1 -t pthd /bin/bash -xec 'bash tests/run_gpu_tests.sh 2' diff --git a/ignite/metrics/clustering/calinski_harabasz_score.py b/ignite/metrics/clustering/calinski_harabasz_score.py index fe58ac46151..79f8dc99ba5 100644 --- a/ignite/metrics/clustering/calinski_harabasz_score.py +++ b/ignite/metrics/clustering/calinski_harabasz_score.py @@ -11,8 +11,8 @@ def _calinski_harabasz_score(features: Tensor, labels: Tensor) -> float: from sklearn.metrics import calinski_harabasz_score - np_features = features.numpy() - np_labels = labels.numpy() + np_features = features.cpu().numpy() + np_labels = labels.cpu().numpy() score = calinski_harabasz_score(np_features, np_labels) return score diff --git a/ignite/metrics/clustering/davies_bouldin_score.py b/ignite/metrics/clustering/davies_bouldin_score.py index b34ec69f51a..afea0518951 100644 --- a/ignite/metrics/clustering/davies_bouldin_score.py +++ b/ignite/metrics/clustering/davies_bouldin_score.py @@ -11,8 +11,8 @@ def _davies_bouldin_score(features: Tensor, labels: Tensor) -> float: from sklearn.metrics import davies_bouldin_score - np_features = features.numpy() - np_labels = labels.numpy() + np_features = features.cpu().numpy() + np_labels = labels.cpu().numpy() score = davies_bouldin_score(np_features, np_labels) return score diff --git a/ignite/metrics/clustering/silhouette_score.py b/ignite/metrics/clustering/silhouette_score.py index 39b28c5d040..48a59d583ec 100644 --- a/ignite/metrics/clustering/silhouette_score.py +++ b/ignite/metrics/clustering/silhouette_score.py @@ -111,7 +111,7 @@ def __init__( def _silhouette_score(self, features: Tensor, labels: Tensor) -> float: from sklearn.metrics import silhouette_score - np_features = features.numpy() - np_labels = labels.numpy() + np_features = features.cpu().numpy() + np_labels = labels.cpu().numpy() score = silhouette_score(np_features, np_labels, **self._silhouette_kwargs) return score diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py index 7ad87b22402..34d876a3659 100644 --- a/ignite/metrics/regression/kendall_correlation.py +++ b/ignite/metrics/regression/kendall_correlation.py @@ -16,8 +16,8 @@ def _get_kendall_tau(variant: str = "b") -> Callable[[Tensor, Tensor], float]: raise ValueError(f"variant accepts 'b' or 'c', got {variant!r}.") def _tau(predictions: Tensor, targets: Tensor) -> float: - np_preds = predictions.flatten().numpy() - np_targets = targets.flatten().numpy() + np_preds = predictions.flatten().cpu().numpy() + np_targets = targets.flatten().cpu().numpy() r = kendalltau(np_preds, np_targets, variant=variant).statistic return r diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py index 7f126d6e56b..cbd89f67c9d 100644 --- a/ignite/metrics/regression/spearman_correlation.py +++ b/ignite/metrics/regression/spearman_correlation.py @@ -12,8 +12,8 @@ def _spearman_r(predictions: Tensor, targets: Tensor) -> float: from scipy.stats import spearmanr - np_preds = predictions.flatten().numpy() - np_targets = targets.flatten().numpy() + np_preds = predictions.flatten().cpu().numpy() + np_targets = targets.flatten().cpu().numpy() r = spearmanr(np_preds, np_targets).statistic return r diff --git a/tests/common_test_functionality.sh b/tests/common_test_functionality.sh index 6e60947f927..91003eddc09 100644 --- a/tests/common_test_functionality.sh +++ b/tests/common_test_functionality.sh @@ -85,7 +85,6 @@ run_tests() { skip_distrib_opt="" fi - echo [pytest] > pytest.ini ; echo "cache_dir=${cache_dir}" >> pytest.ini # Assemble options for the pytest command @@ -103,8 +102,8 @@ run_tests() { # Run the command if [ "$trap_deselected_exit_code" -eq "1" ]; then - CUDA_VISIBLE_DEVICES="" eval "pytest ${pytest_args}" || { exit_code=$?; if [ "$exit_code" -eq ${last_failed_no_failures_code} ]; then echo "All tests deselected"; else exit $exit_code; fi; } + eval "pytest ${pytest_args}" || { exit_code=$?; if [ "$exit_code" -eq ${last_failed_no_failures_code} ]; then echo "All tests deselected"; else exit $exit_code; fi; } else - CUDA_VISIBLE_DEVICES="" eval "pytest ${pytest_args}" + eval "pytest ${pytest_args}" fi } diff --git a/tests/ignite/metrics/test_classification_report.py b/tests/ignite/metrics/test_classification_report.py index 87e328c8051..cae8b5145f5 100644 --- a/tests/ignite/metrics/test_classification_report.py +++ b/tests/ignite/metrics/test_classification_report.py @@ -164,6 +164,23 @@ def update(engine, i): @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") @pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0") def test_distrib_nccl_gpu(distributed_context_single_node_nccl): + + pytest.skip("Temporarily skip failing test. See https://github.com/pytorch/ignite/pull/3301") + # When run with 2 devices: + # tests/ignite/metrics/test_classification_report.py::test_distrib_nccl_gpu Fatal Python error: Aborted + # Thread 0x00007fac95c95700 (most recent call first): + # + + # Thread 0x00007facbb89b700 (most recent call first): + # + + # Thread 0x00007fae637f4700 (most recent call first): + # File "", line 534 in read + # File "", line 567 in from_io + # File "", line 1160 in _thread_receiver + # File "", line 341 in run + # File "", line 411 in _perform_spawn + device = idist.device() _test_integration_multiclass(device, True) _test_integration_multiclass(device, False) diff --git a/tests/ignite/metrics/test_hsic.py b/tests/ignite/metrics/test_hsic.py index 57af5fa2862..28fe5c1f97d 100644 --- a/tests/ignite/metrics/test_hsic.py +++ b/tests/ignite/metrics/test_hsic.py @@ -139,10 +139,10 @@ def test_integration(self, sigma_x: float, sigma_y: float): metric_devices.append(device) for metric_device in metric_devices: - x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device) + x = torch.randn((n_iters * batch_size, n_dims_x), device=device).float() lin = nn.Linear(n_dims_x, n_dims_y).to(device) - y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y) * 1e-4 + y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y, device=x.device) * 1e-4 def data_loader(i, input_x, input_y): return input_x[i * batch_size : (i + 1) * batch_size], input_y[i * batch_size : (i + 1) * batch_size] diff --git a/tests/run_cpu_tests.sh b/tests/run_cpu_tests.sh index 8d387f5542e..f52988a6818 100644 --- a/tests/run_cpu_tests.sh +++ b/tests/run_cpu_tests.sh @@ -6,8 +6,7 @@ skip_distrib_tests=${SKIP_DISTRIB_TESTS:-0} use_last_failed=${USE_LAST_FAILED:-0} match_tests_expression=${1:-""} - -run_tests \ +CUDA_VISIBLE_DEVICES="" run_tests \ --core_args "--tx 4*popen//python=python -vvv tests/ignite" \ --cache_dir ".cpu-not-distrib" \ --skip_distrib_tests "${skip_distrib_tests}" \ @@ -21,7 +20,7 @@ if [ "${skip_distrib_tests}" -eq "1" ]; then fi # Run 2 processes with --dist=each -run_tests \ +CUDA_VISIBLE_DEVICES="" run_tests \ --core_args "-m distributed -vvv tests/ignite" \ --world_size 2 \ --cache_dir ".cpu-distrib" \ diff --git a/tests/run_gpu_tests.sh b/tests/run_gpu_tests.sh index 26497f19c83..c86d1d0746e 100644 --- a/tests/run_gpu_tests.sh +++ b/tests/run_gpu_tests.sh @@ -2,26 +2,26 @@ source "$(dirname "$0")/common_test_functionality.sh" set -xeu -skip_distrib_tests=${SKIP_DISTRIB_TESTS:-1} +# https://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_06_02 +skip_distrib_tests=${SKIP_DISTRIB_TESTS:-0} use_last_failed=${USE_LAST_FAILED:-0} ngpus=${1:-1} match_tests_expression=${2:-""} if [ -z "$match_tests_expression" ]; then - cuda_pattern="cuda" + cuda_pattern="cuda or nccl or gloo" else - cuda_pattern="cuda and $match_tests_expression" + cuda_pattern="(cuda or nccl or gloo) and $match_tests_expression" fi run_tests \ - --core_args "-vvv tests/ignite" \ + --core_args "-vvv tests/ignite -m 'not distributed'" \ --cache_dir ".gpu-cuda" \ --skip_distrib_tests "${skip_distrib_tests}" \ --use_coverage 1 \ --match_tests_expression "${cuda_pattern}" \ --use_last_failed ${use_last_failed} -# https://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_06_02 if [ "${skip_distrib_tests}" -eq "1" ]; then exit 0 fi From 4f462109858291f3499e8fab809d91e69a9d9532 Mon Sep 17 00:00:00 2001 From: vfdev Date: Tue, 3 Dec 2024 13:39:54 +0100 Subject: [PATCH 3/4] Updated GpuInfo metric, pynvml<12 (#3311) --- ignite/contrib/engines/common.py | 2 +- ignite/metrics/gpu_info.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py index 09f769a18d0..bcfa54be55e 100644 --- a/ignite/contrib/engines/common.py +++ b/ignite/contrib/engines/common.py @@ -78,7 +78,7 @@ def setup_common_training_handlers( lr_scheduler: learning rate scheduler as native torch LRScheduler or ignite's parameter scheduler. with_gpu_stats: if True, :class:`~ignite.metrics.GpuInfo` is attached to the - trainer. This requires `pynvml` package to be installed. + trainer. This requires `pynvml<12` package to be installed. output_names: list of names associated with `update_function` output dictionary. with_pbars: if True, two progress bars on epochs and optionally on iterations are attached. Default, True. diff --git a/ignite/metrics/gpu_info.py b/ignite/metrics/gpu_info.py index 96ed4f07c57..d13bbd8a1da 100644 --- a/ignite/metrics/gpu_info.py +++ b/ignite/metrics/gpu_info.py @@ -10,7 +10,7 @@ class GpuInfo(Metric): """Provides GPU information: a) used memory percentage, b) gpu utilization percentage values as Metric - on each iterations. + on each iterations. This metric requires `pynvml `_ package of version `<12`. .. Note :: @@ -39,7 +39,7 @@ def __init__(self) -> None: except ImportError: raise ModuleNotFoundError( "This contrib module requires pynvml to be installed. " - "Please install it with command: \n pip install pynvml" + "Please install it with command: \n pip install 'pynvml<12'" ) # Let's check available devices if not torch.cuda.is_available(): From 6f8ad2a16b2d82fd6b2b83b849b86160fe8a8b6a Mon Sep 17 00:00:00 2001 From: Fabio Bonassi Date: Tue, 3 Dec 2024 17:28:22 +0100 Subject: [PATCH 4/4] =?UTF-8?q?Give=20the=20option=20to=20terminate=20the?= =?UTF-8?q?=20engine=20without=20firing=20Events.COMPLET=E2=80=A6=20(#3309?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Give the option to terminate the engine without firing Events.COMPLETED. The default behaviour is not changed. Note that even though Events.COMPLETED is not fired, its timer is updated. * Update ignite/engine/engine.py Co-authored-by: vfdev * Update ignite/engine/engine.py Co-authored-by: vfdev * Update ignite/engine/engine.py Co-authored-by: vfdev * Update ignite/engine/engine.py Co-authored-by: vfdev * Update ignite/engine/events.py Co-authored-by: vfdev * Argument `skip_event_completed` renamed to `skip_completed` * - Fixed docs broken links. - Do not update self.state.times[Events.COMPLETED.name] if terminated - Fixed unit test * Update ignite/engine/engine.py Co-authored-by: vfdev * Refactoring and patching. - Engine time logging moved out of the if clause. In the log message "completed" has been replaced with "finished" to avoid confusion. - Same changes applied to the method `_internal_run_legacy()` * Restored .gitignore Sorry for accidentally including it into the previous commit! * Update ignite/engine/events.py * Fixed typo in test_engine.py * Parametrized test for engine.terminate(skip_completed) * Update event table * Fixed documentation --------- Co-authored-by: vfdev --- ignite/engine/engine.py | 41 +++++++++++++------ ignite/engine/events.py | 25 ++++++++---- tests/ignite/contrib/engines/test_common.py | 1 - tests/ignite/engine/test_engine.py | 45 ++++++++++++++------- 4 files changed, 76 insertions(+), 36 deletions(-) diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py index 27a949cacca..e2a14898607 100644 --- a/ignite/engine/engine.py +++ b/ignite/engine/engine.py @@ -140,6 +140,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]): self._process_function = process_function self.last_event_name: Optional[Events] = None self.should_terminate = False + self.skip_completed_after_termination = False self.should_terminate_single_epoch = False self.should_interrupt = False self.state = State() @@ -538,7 +539,7 @@ def call_interrupt(): self.logger.info("interrupt signaled. Engine will interrupt the run after current iteration is finished.") self.should_interrupt = True - def terminate(self) -> None: + def terminate(self, skip_completed: bool = False) -> None: """Sends terminate signal to the engine, so that it terminates completely the run. The run is terminated after the event on which ``terminate`` method was called. The following events are triggered: @@ -547,6 +548,9 @@ def terminate(self) -> None: - :attr:`~ignite.engine.events.Events.TERMINATE` - :attr:`~ignite.engine.events.Events.COMPLETED` + Args: + skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after + :attr:`~ignite.engine.events.Events.TERMINATE`. Default is False. Examples: .. testcode:: @@ -617,9 +621,12 @@ def terminate(): .. versionchanged:: 0.4.10 Behaviour changed, for details see https://github.com/pytorch/ignite/issues/2669 + .. versionchanged:: 0.5.2 + Added `skip_completed` flag """ self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.") self.should_terminate = True + self.skip_completed_after_termination = skip_completed def terminate_epoch(self) -> None: """Sends terminate signal to the engine, so that it terminates the current epoch. The run @@ -993,13 +1000,17 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]: time_taken = time.time() - start_time # time is available for handlers but must be updated after fire self.state.times[Events.COMPLETED.name] = time_taken - handlers_start_time = time.time() - self._fire_event(Events.COMPLETED) - time_taken += time.time() - handlers_start_time - # update time wrt handlers - self.state.times[Events.COMPLETED.name] = time_taken + + # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True` + if not (self.should_terminate and self.skip_completed_after_termination): + handlers_start_time = time.time() + self._fire_event(Events.COMPLETED) + time_taken += time.time() - handlers_start_time + # update time wrt handlers + self.state.times[Events.COMPLETED.name] = time_taken + hours, mins, secs = _to_hours_mins_secs(time_taken) - self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}") + self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}") except BaseException as e: self._dataloader_iter = None @@ -1174,13 +1185,17 @@ def _internal_run_legacy(self) -> State: time_taken = time.time() - start_time # time is available for handlers but must be updated after fire self.state.times[Events.COMPLETED.name] = time_taken - handlers_start_time = time.time() - self._fire_event(Events.COMPLETED) - time_taken += time.time() - handlers_start_time - # update time wrt handlers - self.state.times[Events.COMPLETED.name] = time_taken + + # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True` + if not (self.should_terminate and self.skip_completed_after_termination): + handlers_start_time = time.time() + self._fire_event(Events.COMPLETED) + time_taken += time.time() - handlers_start_time + # update time wrt handlers + self.state.times[Events.COMPLETED.name] = time_taken + hours, mins, secs = _to_hours_mins_secs(time_taken) - self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}") + self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}") except BaseException as e: self._dataloader_iter = None diff --git a/ignite/engine/events.py b/ignite/engine/events.py index 9dd99348492..87622d3415c 100644 --- a/ignite/engine/events.py +++ b/ignite/engine/events.py @@ -259,36 +259,47 @@ class Events(EventEnum): - TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch, after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or :meth:`~ignite.engine.engine.Engine.terminate()` call. + - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even + when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called. - TERMINATE : triggered when the run is about to end completely, after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call. - - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even - when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called. - - COMPLETED : triggered when engine's run is completed + - COMPLETED : triggered when engine's run is completed or terminated with + :meth:`~ignite.engine.engine.Engine.terminate()`, unless the flag + `skip_completed` is set to True. The table below illustrates which events are triggered when various termination methods are called. .. list-table:: - :widths: 24 25 33 18 + :widths: 35 38 28 20 20 :header-rows: 1 * - Method - - EVENT_COMPLETED - TERMINATE_SINGLE_EPOCH + - EPOCH_COMPLETED - TERMINATE + - COMPLETED * - no termination - - ✔ - ✗ + - ✔ - ✗ + - ✔ * - :meth:`~ignite.engine.engine.Engine.terminate_epoch()` - ✔ - ✔ - ✗ + - ✔ * - :meth:`~ignite.engine.engine.Engine.terminate()` - ✗ - ✔ - ✔ + - ✔ + * - :meth:`~ignite.engine.engine.Engine.terminate()` with `skip_completed=True` + - ✗ + - ✔ + - ✔ + - ✗ Since v0.3.0, Events become more flexible and allow to pass an event filter to the Engine: @@ -357,7 +368,7 @@ class CustomEvents(EventEnum): STARTED = "started" """triggered when engine's run is started.""" COMPLETED = "completed" - """triggered when engine's run is completed""" + """triggered when engine's run is completed, or after receiving terminate() call.""" ITERATION_STARTED = "iteration_started" """triggered when an iteration is started.""" diff --git a/tests/ignite/contrib/engines/test_common.py b/tests/ignite/contrib/engines/test_common.py index d0100be9e8d..e14042e62c1 100644 --- a/tests/ignite/contrib/engines/test_common.py +++ b/tests/ignite/contrib/engines/test_common.py @@ -8,7 +8,6 @@ from torch.utils.data.distributed import DistributedSampler import ignite.distributed as idist - import ignite.handlers as handlers from ignite.contrib.engines.common import ( _setup_logging, diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py index 13021242650..fcb0299aa22 100644 --- a/tests/ignite/engine/test_engine.py +++ b/tests/ignite/engine/test_engine.py @@ -40,11 +40,14 @@ class TestEngine: def set_interrupt_resume_enabled(self, interrupt_resume_enabled): Engine.interrupt_resume_enabled = interrupt_resume_enabled - def test_terminate(self): + @pytest.mark.parametrize("skip_completed", [True, False]) + def test_terminate(self, skip_completed): engine = Engine(lambda e, b: 1) assert not engine.should_terminate - engine.terminate() + assert not engine.skip_completed_after_termination + engine.terminate(skip_completed) assert engine.should_terminate + assert engine.skip_completed_after_termination == skip_completed def test_invalid_process_raises_with_invalid_signature(self): with pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"): @@ -236,25 +239,32 @@ def check_iter_and_data(): assert num_calls_check_iter_epoch == 1 @pytest.mark.parametrize( - "terminate_event, e, i", + "terminate_event, e, i, skip_completed", [ - (Events.STARTED, 0, 0), - (Events.EPOCH_STARTED(once=2), 2, None), - (Events.EPOCH_COMPLETED(once=2), 2, None), - (Events.GET_BATCH_STARTED(once=12), None, 12), - (Events.GET_BATCH_COMPLETED(once=12), None, 12), - (Events.ITERATION_STARTED(once=14), None, 14), - (Events.ITERATION_COMPLETED(once=14), None, 14), + (Events.STARTED, 0, 0, True), + (Events.EPOCH_STARTED(once=2), 2, None, True), + (Events.EPOCH_COMPLETED(once=2), 2, None, True), + (Events.GET_BATCH_STARTED(once=12), None, 12, True), + (Events.GET_BATCH_COMPLETED(once=12), None, 12, False), + (Events.ITERATION_STARTED(once=14), None, 14, True), + (Events.ITERATION_COMPLETED(once=14), None, 14, True), + (Events.STARTED, 0, 0, False), + (Events.EPOCH_STARTED(once=2), 2, None, False), + (Events.EPOCH_COMPLETED(once=2), 2, None, False), + (Events.GET_BATCH_STARTED(once=12), None, 12, False), + (Events.GET_BATCH_COMPLETED(once=12), None, 12, False), + (Events.ITERATION_STARTED(once=14), None, 14, False), + (Events.ITERATION_COMPLETED(once=14), None, 14, False), ], ) - def test_terminate_events_sequence(self, terminate_event, e, i): + def test_terminate_events_sequence(self, terminate_event, e, i, skip_completed): engine = RecordedEngine(MagicMock(return_value=1)) data = range(10) max_epochs = 5 @engine.on(terminate_event) def call_terminate(): - engine.terminate() + engine.terminate(skip_completed) @engine.on(Events.EXCEPTION_RAISED) def assert_no_exceptions(ee): @@ -271,10 +281,15 @@ def assert_no_exceptions(ee): if e is None: e = i // len(data) + 1 + if skip_completed: + assert engine.called_events[-1] == (e, i, Events.TERMINATE) + assert engine.called_events[-2] == (e, i, terminate_event) + else: + assert engine.called_events[-1] == (e, i, Events.COMPLETED) + assert engine.called_events[-2] == (e, i, Events.TERMINATE) + assert engine.called_events[-3] == (e, i, terminate_event) + assert engine.called_events[0] == (0, 0, Events.STARTED) - assert engine.called_events[-1] == (e, i, Events.COMPLETED) - assert engine.called_events[-2] == (e, i, Events.TERMINATE) - assert engine.called_events[-3] == (e, i, terminate_event) assert engine._dataloader_iter is None @pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])