diff --git a/.github/workflows/gpu-tests.yml b/.github/workflows/gpu-tests.yml
index 0a72711fbddc..13c628ad302c 100644
--- a/.github/workflows/gpu-tests.yml
+++ b/.github/workflows/gpu-tests.yml
@@ -124,7 +124,7 @@ jobs:
uses: nick-fields/retry@v2.9.0
with:
max_attempts: 5
- timeout_minutes: 25
+ timeout_minutes: 45
shell: bash
command: docker exec -t pthd /bin/bash -xec 'bash tests/run_gpu_tests.sh 2'
new_command_on_retry: docker exec -e USE_LAST_FAILED=1 -t pthd /bin/bash -xec 'bash tests/run_gpu_tests.sh 2'
diff --git a/ignite/contrib/engines/common.py b/ignite/contrib/engines/common.py
index 09f769a18d0f..bcfa54be55ea 100644
--- a/ignite/contrib/engines/common.py
+++ b/ignite/contrib/engines/common.py
@@ -78,7 +78,7 @@ def setup_common_training_handlers(
lr_scheduler: learning rate scheduler
as native torch LRScheduler or ignite's parameter scheduler.
with_gpu_stats: if True, :class:`~ignite.metrics.GpuInfo` is attached to the
- trainer. This requires `pynvml` package to be installed.
+ trainer. This requires `pynvml<12` package to be installed.
output_names: list of names associated with `update_function` output dictionary.
with_pbars: if True, two progress bars on epochs and optionally on iterations are attached.
Default, True.
diff --git a/ignite/engine/engine.py b/ignite/engine/engine.py
index 27a949cacca2..e2a148986075 100644
--- a/ignite/engine/engine.py
+++ b/ignite/engine/engine.py
@@ -140,6 +140,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
self._process_function = process_function
self.last_event_name: Optional[Events] = None
self.should_terminate = False
+ self.skip_completed_after_termination = False
self.should_terminate_single_epoch = False
self.should_interrupt = False
self.state = State()
@@ -538,7 +539,7 @@ def call_interrupt():
self.logger.info("interrupt signaled. Engine will interrupt the run after current iteration is finished.")
self.should_interrupt = True
- def terminate(self) -> None:
+ def terminate(self, skip_completed: bool = False) -> None:
"""Sends terminate signal to the engine, so that it terminates completely the run. The run is
terminated after the event on which ``terminate`` method was called. The following events are triggered:
@@ -547,6 +548,9 @@ def terminate(self) -> None:
- :attr:`~ignite.engine.events.Events.TERMINATE`
- :attr:`~ignite.engine.events.Events.COMPLETED`
+ Args:
+ skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after
+ :attr:`~ignite.engine.events.Events.TERMINATE`. Default is False.
Examples:
.. testcode::
@@ -617,9 +621,12 @@ def terminate():
.. versionchanged:: 0.4.10
Behaviour changed, for details see https://github.com/pytorch/ignite/issues/2669
+ .. versionchanged:: 0.5.2
+ Added `skip_completed` flag
"""
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
self.should_terminate = True
+ self.skip_completed_after_termination = skip_completed
def terminate_epoch(self) -> None:
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
@@ -993,13 +1000,17 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken
- handlers_start_time = time.time()
- self._fire_event(Events.COMPLETED)
- time_taken += time.time() - handlers_start_time
- # update time wrt handlers
- self.state.times[Events.COMPLETED.name] = time_taken
+
+ # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
+ if not (self.should_terminate and self.skip_completed_after_termination):
+ handlers_start_time = time.time()
+ self._fire_event(Events.COMPLETED)
+ time_taken += time.time() - handlers_start_time
+ # update time wrt handlers
+ self.state.times[Events.COMPLETED.name] = time_taken
+
hours, mins, secs = _to_hours_mins_secs(time_taken)
- self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
+ self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
except BaseException as e:
self._dataloader_iter = None
@@ -1174,13 +1185,17 @@ def _internal_run_legacy(self) -> State:
time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken
- handlers_start_time = time.time()
- self._fire_event(Events.COMPLETED)
- time_taken += time.time() - handlers_start_time
- # update time wrt handlers
- self.state.times[Events.COMPLETED.name] = time_taken
+
+ # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
+ if not (self.should_terminate and self.skip_completed_after_termination):
+ handlers_start_time = time.time()
+ self._fire_event(Events.COMPLETED)
+ time_taken += time.time() - handlers_start_time
+ # update time wrt handlers
+ self.state.times[Events.COMPLETED.name] = time_taken
+
hours, mins, secs = _to_hours_mins_secs(time_taken)
- self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
+ self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
except BaseException as e:
self._dataloader_iter = None
diff --git a/ignite/engine/events.py b/ignite/engine/events.py
index 9dd99348492b..87622d3415cc 100644
--- a/ignite/engine/events.py
+++ b/ignite/engine/events.py
@@ -259,36 +259,47 @@ class Events(EventEnum):
- TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch,
after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or
:meth:`~ignite.engine.engine.Engine.terminate()` call.
+ - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
+ when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
- TERMINATE : triggered when the run is about to end completely,
after receiving :meth:`~ignite.engine.engine.Engine.terminate()` call.
- - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
- when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
- - COMPLETED : triggered when engine's run is completed
+ - COMPLETED : triggered when engine's run is completed or terminated with
+ :meth:`~ignite.engine.engine.Engine.terminate()`, unless the flag
+ `skip_completed` is set to True.
The table below illustrates which events are triggered when various termination methods are called.
.. list-table::
- :widths: 24 25 33 18
+ :widths: 35 38 28 20 20
:header-rows: 1
* - Method
- - EVENT_COMPLETED
- TERMINATE_SINGLE_EPOCH
+ - EPOCH_COMPLETED
- TERMINATE
+ - COMPLETED
* - no termination
- - ✔
- ✗
+ - ✔
- ✗
+ - ✔
* - :meth:`~ignite.engine.engine.Engine.terminate_epoch()`
- ✔
- ✔
- ✗
+ - ✔
* - :meth:`~ignite.engine.engine.Engine.terminate()`
- ✗
- ✔
- ✔
+ - ✔
+ * - :meth:`~ignite.engine.engine.Engine.terminate()` with `skip_completed=True`
+ - ✗
+ - ✔
+ - ✔
+ - ✗
Since v0.3.0, Events become more flexible and allow to pass an event filter to the Engine:
@@ -357,7 +368,7 @@ class CustomEvents(EventEnum):
STARTED = "started"
"""triggered when engine's run is started."""
COMPLETED = "completed"
- """triggered when engine's run is completed"""
+ """triggered when engine's run is completed, or after receiving terminate() call."""
ITERATION_STARTED = "iteration_started"
"""triggered when an iteration is started."""
diff --git a/ignite/metrics/clustering/calinski_harabasz_score.py b/ignite/metrics/clustering/calinski_harabasz_score.py
index fe58ac461517..79f8dc99ba50 100644
--- a/ignite/metrics/clustering/calinski_harabasz_score.py
+++ b/ignite/metrics/clustering/calinski_harabasz_score.py
@@ -11,8 +11,8 @@
def _calinski_harabasz_score(features: Tensor, labels: Tensor) -> float:
from sklearn.metrics import calinski_harabasz_score
- np_features = features.numpy()
- np_labels = labels.numpy()
+ np_features = features.cpu().numpy()
+ np_labels = labels.cpu().numpy()
score = calinski_harabasz_score(np_features, np_labels)
return score
diff --git a/ignite/metrics/clustering/davies_bouldin_score.py b/ignite/metrics/clustering/davies_bouldin_score.py
index b34ec69f51ad..afea0518951b 100644
--- a/ignite/metrics/clustering/davies_bouldin_score.py
+++ b/ignite/metrics/clustering/davies_bouldin_score.py
@@ -11,8 +11,8 @@
def _davies_bouldin_score(features: Tensor, labels: Tensor) -> float:
from sklearn.metrics import davies_bouldin_score
- np_features = features.numpy()
- np_labels = labels.numpy()
+ np_features = features.cpu().numpy()
+ np_labels = labels.cpu().numpy()
score = davies_bouldin_score(np_features, np_labels)
return score
diff --git a/ignite/metrics/clustering/silhouette_score.py b/ignite/metrics/clustering/silhouette_score.py
index 39b28c5d0409..48a59d583ec4 100644
--- a/ignite/metrics/clustering/silhouette_score.py
+++ b/ignite/metrics/clustering/silhouette_score.py
@@ -111,7 +111,7 @@ def __init__(
def _silhouette_score(self, features: Tensor, labels: Tensor) -> float:
from sklearn.metrics import silhouette_score
- np_features = features.numpy()
- np_labels = labels.numpy()
+ np_features = features.cpu().numpy()
+ np_labels = labels.cpu().numpy()
score = silhouette_score(np_features, np_labels, **self._silhouette_kwargs)
return score
diff --git a/ignite/metrics/gpu_info.py b/ignite/metrics/gpu_info.py
index 96ed4f07c57c..d13bbd8a1dae 100644
--- a/ignite/metrics/gpu_info.py
+++ b/ignite/metrics/gpu_info.py
@@ -10,7 +10,7 @@
class GpuInfo(Metric):
"""Provides GPU information: a) used memory percentage, b) gpu utilization percentage values as Metric
- on each iterations.
+ on each iterations. This metric requires `pynvml `_ package of version `<12`.
.. Note ::
@@ -39,7 +39,7 @@ def __init__(self) -> None:
except ImportError:
raise ModuleNotFoundError(
"This contrib module requires pynvml to be installed. "
- "Please install it with command: \n pip install pynvml"
+ "Please install it with command: \n pip install 'pynvml<12'"
)
# Let's check available devices
if not torch.cuda.is_available():
diff --git a/ignite/metrics/regression/kendall_correlation.py b/ignite/metrics/regression/kendall_correlation.py
index 7ad87b224024..34d876a36599 100644
--- a/ignite/metrics/regression/kendall_correlation.py
+++ b/ignite/metrics/regression/kendall_correlation.py
@@ -16,8 +16,8 @@ def _get_kendall_tau(variant: str = "b") -> Callable[[Tensor, Tensor], float]:
raise ValueError(f"variant accepts 'b' or 'c', got {variant!r}.")
def _tau(predictions: Tensor, targets: Tensor) -> float:
- np_preds = predictions.flatten().numpy()
- np_targets = targets.flatten().numpy()
+ np_preds = predictions.flatten().cpu().numpy()
+ np_targets = targets.flatten().cpu().numpy()
r = kendalltau(np_preds, np_targets, variant=variant).statistic
return r
diff --git a/ignite/metrics/regression/spearman_correlation.py b/ignite/metrics/regression/spearman_correlation.py
index 7f126d6e56be..cbd89f67c9d0 100644
--- a/ignite/metrics/regression/spearman_correlation.py
+++ b/ignite/metrics/regression/spearman_correlation.py
@@ -12,8 +12,8 @@
def _spearman_r(predictions: Tensor, targets: Tensor) -> float:
from scipy.stats import spearmanr
- np_preds = predictions.flatten().numpy()
- np_targets = targets.flatten().numpy()
+ np_preds = predictions.flatten().cpu().numpy()
+ np_targets = targets.flatten().cpu().numpy()
r = spearmanr(np_preds, np_targets).statistic
return r
diff --git a/requirements-dev.txt b/requirements-dev.txt
index d475e556cdff..91b560e56530 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -21,7 +21,7 @@ mlflow
neptune-client>=0.16.17
tensorboard
torchvision
-pynvml
+pynvml<12 # pynvml module was removed in 12.X, is not developed or maintained. We should replace pynvml with something else.
clearml
scikit-image
py-rouge
diff --git a/tests/common_test_functionality.sh b/tests/common_test_functionality.sh
index 6e60947f927b..91003eddc092 100644
--- a/tests/common_test_functionality.sh
+++ b/tests/common_test_functionality.sh
@@ -85,7 +85,6 @@ run_tests() {
skip_distrib_opt=""
fi
-
echo [pytest] > pytest.ini ; echo "cache_dir=${cache_dir}" >> pytest.ini
# Assemble options for the pytest command
@@ -103,8 +102,8 @@ run_tests() {
# Run the command
if [ "$trap_deselected_exit_code" -eq "1" ]; then
- CUDA_VISIBLE_DEVICES="" eval "pytest ${pytest_args}" || { exit_code=$?; if [ "$exit_code" -eq ${last_failed_no_failures_code} ]; then echo "All tests deselected"; else exit $exit_code; fi; }
+ eval "pytest ${pytest_args}" || { exit_code=$?; if [ "$exit_code" -eq ${last_failed_no_failures_code} ]; then echo "All tests deselected"; else exit $exit_code; fi; }
else
- CUDA_VISIBLE_DEVICES="" eval "pytest ${pytest_args}"
+ eval "pytest ${pytest_args}"
fi
}
diff --git a/tests/ignite/contrib/engines/test_common.py b/tests/ignite/contrib/engines/test_common.py
index d0100be9e8da..e14042e62c15 100644
--- a/tests/ignite/contrib/engines/test_common.py
+++ b/tests/ignite/contrib/engines/test_common.py
@@ -8,7 +8,6 @@
from torch.utils.data.distributed import DistributedSampler
import ignite.distributed as idist
-
import ignite.handlers as handlers
from ignite.contrib.engines.common import (
_setup_logging,
diff --git a/tests/ignite/engine/test_engine.py b/tests/ignite/engine/test_engine.py
index 130212426504..fcb0299aa22d 100644
--- a/tests/ignite/engine/test_engine.py
+++ b/tests/ignite/engine/test_engine.py
@@ -40,11 +40,14 @@ class TestEngine:
def set_interrupt_resume_enabled(self, interrupt_resume_enabled):
Engine.interrupt_resume_enabled = interrupt_resume_enabled
- def test_terminate(self):
+ @pytest.mark.parametrize("skip_completed", [True, False])
+ def test_terminate(self, skip_completed):
engine = Engine(lambda e, b: 1)
assert not engine.should_terminate
- engine.terminate()
+ assert not engine.skip_completed_after_termination
+ engine.terminate(skip_completed)
assert engine.should_terminate
+ assert engine.skip_completed_after_termination == skip_completed
def test_invalid_process_raises_with_invalid_signature(self):
with pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"):
@@ -236,25 +239,32 @@ def check_iter_and_data():
assert num_calls_check_iter_epoch == 1
@pytest.mark.parametrize(
- "terminate_event, e, i",
+ "terminate_event, e, i, skip_completed",
[
- (Events.STARTED, 0, 0),
- (Events.EPOCH_STARTED(once=2), 2, None),
- (Events.EPOCH_COMPLETED(once=2), 2, None),
- (Events.GET_BATCH_STARTED(once=12), None, 12),
- (Events.GET_BATCH_COMPLETED(once=12), None, 12),
- (Events.ITERATION_STARTED(once=14), None, 14),
- (Events.ITERATION_COMPLETED(once=14), None, 14),
+ (Events.STARTED, 0, 0, True),
+ (Events.EPOCH_STARTED(once=2), 2, None, True),
+ (Events.EPOCH_COMPLETED(once=2), 2, None, True),
+ (Events.GET_BATCH_STARTED(once=12), None, 12, True),
+ (Events.GET_BATCH_COMPLETED(once=12), None, 12, False),
+ (Events.ITERATION_STARTED(once=14), None, 14, True),
+ (Events.ITERATION_COMPLETED(once=14), None, 14, True),
+ (Events.STARTED, 0, 0, False),
+ (Events.EPOCH_STARTED(once=2), 2, None, False),
+ (Events.EPOCH_COMPLETED(once=2), 2, None, False),
+ (Events.GET_BATCH_STARTED(once=12), None, 12, False),
+ (Events.GET_BATCH_COMPLETED(once=12), None, 12, False),
+ (Events.ITERATION_STARTED(once=14), None, 14, False),
+ (Events.ITERATION_COMPLETED(once=14), None, 14, False),
],
)
- def test_terminate_events_sequence(self, terminate_event, e, i):
+ def test_terminate_events_sequence(self, terminate_event, e, i, skip_completed):
engine = RecordedEngine(MagicMock(return_value=1))
data = range(10)
max_epochs = 5
@engine.on(terminate_event)
def call_terminate():
- engine.terminate()
+ engine.terminate(skip_completed)
@engine.on(Events.EXCEPTION_RAISED)
def assert_no_exceptions(ee):
@@ -271,10 +281,15 @@ def assert_no_exceptions(ee):
if e is None:
e = i // len(data) + 1
+ if skip_completed:
+ assert engine.called_events[-1] == (e, i, Events.TERMINATE)
+ assert engine.called_events[-2] == (e, i, terminate_event)
+ else:
+ assert engine.called_events[-1] == (e, i, Events.COMPLETED)
+ assert engine.called_events[-2] == (e, i, Events.TERMINATE)
+ assert engine.called_events[-3] == (e, i, terminate_event)
+
assert engine.called_events[0] == (0, 0, Events.STARTED)
- assert engine.called_events[-1] == (e, i, Events.COMPLETED)
- assert engine.called_events[-2] == (e, i, Events.TERMINATE)
- assert engine.called_events[-3] == (e, i, terminate_event)
assert engine._dataloader_iter is None
@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])
diff --git a/tests/ignite/metrics/test_classification_report.py b/tests/ignite/metrics/test_classification_report.py
index 87e328c8051e..cae8b5145f55 100644
--- a/tests/ignite/metrics/test_classification_report.py
+++ b/tests/ignite/metrics/test_classification_report.py
@@ -164,6 +164,23 @@ def update(engine, i):
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
+
+ pytest.skip("Temporarily skip failing test. See https://github.com/pytorch/ignite/pull/3301")
+ # When run with 2 devices:
+ # tests/ignite/metrics/test_classification_report.py::test_distrib_nccl_gpu Fatal Python error: Aborted
+ # Thread 0x00007fac95c95700 (most recent call first):
+ #
+
+ # Thread 0x00007facbb89b700 (most recent call first):
+ #
+
+ # Thread 0x00007fae637f4700 (most recent call first):
+ # File "", line 534 in read
+ # File "", line 567 in from_io
+ # File "", line 1160 in _thread_receiver
+ # File "", line 341 in run
+ # File "", line 411 in _perform_spawn
+
device = idist.device()
_test_integration_multiclass(device, True)
_test_integration_multiclass(device, False)
diff --git a/tests/ignite/metrics/test_hsic.py b/tests/ignite/metrics/test_hsic.py
index 57af5fa2862c..28fe5c1f97db 100644
--- a/tests/ignite/metrics/test_hsic.py
+++ b/tests/ignite/metrics/test_hsic.py
@@ -139,10 +139,10 @@ def test_integration(self, sigma_x: float, sigma_y: float):
metric_devices.append(device)
for metric_device in metric_devices:
- x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device)
+ x = torch.randn((n_iters * batch_size, n_dims_x), device=device).float()
lin = nn.Linear(n_dims_x, n_dims_y).to(device)
- y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y) * 1e-4
+ y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y, device=x.device) * 1e-4
def data_loader(i, input_x, input_y):
return input_x[i * batch_size : (i + 1) * batch_size], input_y[i * batch_size : (i + 1) * batch_size]
diff --git a/tests/run_cpu_tests.sh b/tests/run_cpu_tests.sh
index 8d387f5542e7..f52988a68183 100644
--- a/tests/run_cpu_tests.sh
+++ b/tests/run_cpu_tests.sh
@@ -6,8 +6,7 @@ skip_distrib_tests=${SKIP_DISTRIB_TESTS:-0}
use_last_failed=${USE_LAST_FAILED:-0}
match_tests_expression=${1:-""}
-
-run_tests \
+CUDA_VISIBLE_DEVICES="" run_tests \
--core_args "--tx 4*popen//python=python -vvv tests/ignite" \
--cache_dir ".cpu-not-distrib" \
--skip_distrib_tests "${skip_distrib_tests}" \
@@ -21,7 +20,7 @@ if [ "${skip_distrib_tests}" -eq "1" ]; then
fi
# Run 2 processes with --dist=each
-run_tests \
+CUDA_VISIBLE_DEVICES="" run_tests \
--core_args "-m distributed -vvv tests/ignite" \
--world_size 2 \
--cache_dir ".cpu-distrib" \
diff --git a/tests/run_gpu_tests.sh b/tests/run_gpu_tests.sh
index 26497f19c83e..c86d1d0746ee 100644
--- a/tests/run_gpu_tests.sh
+++ b/tests/run_gpu_tests.sh
@@ -2,26 +2,26 @@
source "$(dirname "$0")/common_test_functionality.sh"
set -xeu
-skip_distrib_tests=${SKIP_DISTRIB_TESTS:-1}
+# https://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_06_02
+skip_distrib_tests=${SKIP_DISTRIB_TESTS:-0}
use_last_failed=${USE_LAST_FAILED:-0}
ngpus=${1:-1}
match_tests_expression=${2:-""}
if [ -z "$match_tests_expression" ]; then
- cuda_pattern="cuda"
+ cuda_pattern="cuda or nccl or gloo"
else
- cuda_pattern="cuda and $match_tests_expression"
+ cuda_pattern="(cuda or nccl or gloo) and $match_tests_expression"
fi
run_tests \
- --core_args "-vvv tests/ignite" \
+ --core_args "-vvv tests/ignite -m 'not distributed'" \
--cache_dir ".gpu-cuda" \
--skip_distrib_tests "${skip_distrib_tests}" \
--use_coverage 1 \
--match_tests_expression "${cuda_pattern}" \
--use_last_failed ${use_last_failed}
-# https://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_06_02
if [ "${skip_distrib_tests}" -eq "1" ]; then
exit 0
fi