Commit 1b0ecc2
Merge branch 'master' into fix-deprecated-statement
vfdev-5 authored Dec 4, 2024
2 parents 2ae8f3a + 6f8ad2a commit 1b0ecc2
Showing 18 changed files with 119 additions and 64 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/gpu-tests.yml
@@ -124,7 +124,7 @@ jobs:
uses: nick-fields/retry@v2.9.0
with:
max_attempts: 5
- timeout_minutes: 25
+ timeout_minutes: 45
shell: bash
command: docker exec -t pthd /bin/bash -xec 'bash tests/run_gpu_tests.sh 2'
new_command_on_retry: docker exec -e USE_LAST_FAILED=1 -t pthd /bin/bash -xec 'bash tests/run_gpu_tests.sh 2'
2 changes: 1 addition & 1 deletion ignite/contrib/engines/common.py
@@ -78,7 +78,7 @@ def setup_common_training_handlers(
lr_scheduler: learning rate scheduler
as native torch LRScheduler or ignite's parameter scheduler.
with_gpu_stats: if True, :class:`~ignite.metrics.GpuInfo` is attached to the
- trainer. This requires `pynvml` package to be installed.
+ trainer. This requires `pynvml<12` package to be installed.
output_names: list of names associated with `update_function` output dictionary.
with_pbars: if True, two progress bars on epochs and optionally on iterations are attached.
Default, True.
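Note: a minimal usage sketch of `setup_common_training_handlers` with GPU stats enabled (editor's illustration, not part of the diff; `update_fn` is a hypothetical placeholder). `with_gpu_stats=True` attaches `ignite.metrics.GpuInfo`, which needs `pynvml<12` and a visible NVIDIA GPU, so it is guarded with `torch.cuda.is_available()` here:

    import torch
    from ignite.engine import Engine
    from ignite.contrib.engines.common import setup_common_training_handlers

    def update_fn(engine, batch):
        # Placeholder update step; a real one would run forward/backward passes
        # and return values matching `output_names` below.
        return {"batch_loss": 0.0}

    trainer = Engine(update_fn)
    setup_common_training_handlers(
        trainer,
        with_gpu_stats=torch.cuda.is_available(),  # attaches ignite.metrics.GpuInfo
        output_names=["batch_loss"],
        with_pbars=True,
    )
    trainer.run([0, 1, 2], max_epochs=2)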
41 changes: 28 additions & 13 deletions ignite/engine/engine.py
@@ -140,6 +140,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
self._process_function = process_function
self.last_event_name: Optional[Events] = None
self.should_terminate = False
+ self.skip_completed_after_termination = False
self.should_terminate_single_epoch = False
self.should_interrupt = False
self.state = State()
@@ -538,7 +539,7 @@ def call_interrupt():
self.logger.info("interrupt signaled. Engine will interrupt the run after current iteration is finished.")
self.should_interrupt = True

- def terminate(self) -> None:
+ def terminate(self, skip_completed: bool = False) -> None:
"""Sends terminate signal to the engine, so that it terminates completely the run. The run is
terminated after the event on which ``terminate`` method was called. The following events are triggered:
@@ -547,6 +548,9 @@ def terminate(self) -> None:
- :attr:`~ignite.engine.events.Events.TERMINATE`
- :attr:`~ignite.engine.events.Events.COMPLETED`
+ Args:
+     skip_completed: if True, the event :attr:`~ignite.engine.events.Events.COMPLETED` is not fired after
+         :attr:`~ignite.engine.events.Events.TERMINATE`. Default is False.
Examples:
.. testcode::
@@ -617,9 +621,12 @@ def terminate():
.. versionchanged:: 0.4.10
Behaviour changed, for details see https://github.com/pytorch/ignite/issues/2669
+ .. versionchanged:: 0.5.2
+     Added `skip_completed` flag
"""
self.logger.info("Terminate signaled. Engine will stop after current iteration is finished.")
self.should_terminate = True
+ self.skip_completed_after_termination = skip_completed

def terminate_epoch(self) -> None:
"""Sends terminate signal to the engine, so that it terminates the current epoch. The run
@@ -993,13 +1000,17 @@ def _internal_run_as_gen(self) -> Generator[Any, None, State]:
time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken
- handlers_start_time = time.time()
- self._fire_event(Events.COMPLETED)
- time_taken += time.time() - handlers_start_time
- # update time wrt handlers
- self.state.times[Events.COMPLETED.name] = time_taken

+ # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
+ if not (self.should_terminate and self.skip_completed_after_termination):
+     handlers_start_time = time.time()
+     self._fire_event(Events.COMPLETED)
+     time_taken += time.time() - handlers_start_time
+     # update time wrt handlers
+     self.state.times[Events.COMPLETED.name] = time_taken

hours, mins, secs = _to_hours_mins_secs(time_taken)
self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")

except BaseException as e:
self._dataloader_iter = None
@@ -1174,13 +1185,17 @@ def _internal_run_legacy(self) -> State:
time_taken = time.time() - start_time
# time is available for handlers but must be updated after fire
self.state.times[Events.COMPLETED.name] = time_taken
- handlers_start_time = time.time()
- self._fire_event(Events.COMPLETED)
- time_taken += time.time() - handlers_start_time
- # update time wrt handlers
- self.state.times[Events.COMPLETED.name] = time_taken

+ # do not fire Events.COMPLETED if we terminated the run with flag `skip_completed=True`
+ if not (self.should_terminate and self.skip_completed_after_termination):
+     handlers_start_time = time.time()
+     self._fire_event(Events.COMPLETED)
+     time_taken += time.time() - handlers_start_time
+     # update time wrt handlers
+     self.state.times[Events.COMPLETED.name] = time_taken

hours, mins, secs = _to_hours_mins_secs(time_taken)
self.logger.info(f"Engine run complete. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")
self.logger.info(f"Engine run finished. Time taken: {hours:02d}:{mins:02d}:{secs:06.3f}")

except BaseException as e:
self._dataloader_iter = None
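Note: a minimal sketch of the new flag (editor's illustration, not part of the diff), assuming an ignite build that includes this change (0.5.2+). With `skip_completed=True` the `Events.COMPLETED` handler is never invoked:

    from ignite.engine import Engine, Events

    engine = Engine(lambda e, b: None)

    @engine.on(Events.ITERATION_COMPLETED(once=3))
    def stop_early():
        # Ends the run after iteration 3; Events.TERMINATE still fires,
        # but Events.COMPLETED is skipped.
        engine.terminate(skip_completed=True)

    @engine.on(Events.COMPLETED)
    def on_completed():
        print("run completed")  # not printed when skip_completed=True

    engine.run(range(10), max_epochs=2)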
25 changes: 18 additions & 7 deletions ignite/engine/events.py
@@ -259,36 +259,47 @@ class Events(EventEnum):
- TERMINATE_SINGLE_EPOCH : triggered when the run is about to end the current epoch,
after receiving a :meth:`~ignite.engine.engine.Engine.terminate_epoch()` or
:meth:`~ignite.engine.engine.Engine.terminate()` call.
+ - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
+   when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
- TERMINATE : triggered when the run is about to end completely,
:meth:`~ignite.engine.engine.Engine.terminate()` call.
- - EPOCH_COMPLETED : triggered when the epoch is ended. Note that this is triggered even
-   when :meth:`~ignite.engine.engine.Engine.terminate_epoch()` is called.
- - COMPLETED : triggered when engine's run is completed
+ - COMPLETED : triggered when engine's run is completed or terminated with
+   :meth:`~ignite.engine.engine.Engine.terminate()`, unless the flag
+   `skip_completed` is set to True.
The table below illustrates which events are triggered when various termination methods are called.
.. list-table::
- :widths: 24 25 33 18
+ :widths: 35 38 28 20 20
:header-rows: 1
* - Method
- - EVENT_COMPLETED
- TERMINATE_SINGLE_EPOCH
+ - EPOCH_COMPLETED
+ - TERMINATE
- COMPLETED
* - no termination
- ✔
- ✗
- ✔
- ✗
- ✔
* - :meth:`~ignite.engine.engine.Engine.terminate_epoch()`
- ✔
- ✔
- ✗
- ✔
* - :meth:`~ignite.engine.engine.Engine.terminate()`
- ✗
- ✔
- ✔
- ✔
+ * - :meth:`~ignite.engine.engine.Engine.terminate()` with `skip_completed=True`
+   - ✗
+   - ✔
+   - ✔
+   - ✗
Since v0.3.0, Events become more flexible and allow to pass an event filter to the Engine:
Expand Down Expand Up @@ -357,7 +368,7 @@ class CustomEvents(EventEnum):
STARTED = "started"
"""triggered when engine's run is started."""
COMPLETED = "completed"
"""triggered when engine's run is completed"""
"""triggered when engine's run is completed, or after receiving terminate() call."""

ITERATION_STARTED = "iteration_started"
"""triggered when an iteration is started."""
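Note: a sketch for cross-checking the table above (editor's illustration, not part of the diff; assumes a build with the `skip_completed` flag). It records which termination-related events actually fire:

    from ignite.engine import Engine, Events

    def run(skip_completed):
        fired = []
        engine = Engine(lambda e, b: None)

        def make_recorder(event):
            def recorder(*args, **kwargs):
                # extra args the engine may pass along are ignored
                fired.append(event)
            return recorder

        for event in (Events.TERMINATE_SINGLE_EPOCH, Events.EPOCH_COMPLETED,
                      Events.TERMINATE, Events.COMPLETED):
            engine.add_event_handler(event, make_recorder(event))

        @engine.on(Events.ITERATION_COMPLETED(once=2))
        def stop():
            engine.terminate(skip_completed=skip_completed)

        engine.run(range(5), max_epochs=2)
        return fired

    print(run(skip_completed=False))  # ends with Events.TERMINATE, Events.COMPLETED
    print(run(skip_completed=True))   # Events.COMPLETED is absent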
4 changes: 2 additions & 2 deletions ignite/metrics/clustering/calinski_harabasz_score.py
@@ -11,8 +11,8 @@
def _calinski_harabasz_score(features: Tensor, labels: Tensor) -> float:
from sklearn.metrics import calinski_harabasz_score

-    np_features = features.numpy()
-    np_labels = labels.numpy()
+    np_features = features.cpu().numpy()
+    np_labels = labels.cpu().numpy()
score = calinski_harabasz_score(np_features, np_labels)
return score

4 changes: 2 additions & 2 deletions ignite/metrics/clustering/davies_bouldin_score.py
@@ -11,8 +11,8 @@
def _davies_bouldin_score(features: Tensor, labels: Tensor) -> float:
from sklearn.metrics import davies_bouldin_score

-    np_features = features.numpy()
-    np_labels = labels.numpy()
+    np_features = features.cpu().numpy()
+    np_labels = labels.cpu().numpy()
score = davies_bouldin_score(np_features, np_labels)
return score

4 changes: 2 additions & 2 deletions ignite/metrics/clustering/silhouette_score.py
@@ -111,7 +111,7 @@ def __init__(
def _silhouette_score(self, features: Tensor, labels: Tensor) -> float:
from sklearn.metrics import silhouette_score

-    np_features = features.numpy()
-    np_labels = labels.numpy()
+    np_features = features.cpu().numpy()
+    np_labels = labels.cpu().numpy()
score = silhouette_score(np_features, np_labels, **self._silhouette_kwargs)
return score
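Note: the three clustering metrics above now move tensors to host memory before handing them to scikit-learn. A quick illustration of why (editor's sketch, not part of the diff): `Tensor.numpy()` raises on CUDA tensors, while `.cpu().numpy()` works for tensors on any device:

    import torch

    t = torch.randn(8, 2, device="cuda" if torch.cuda.is_available() else "cpu")

    # For a CUDA tensor, t.numpy() raises:
    #   TypeError: can't convert cuda:0 device type tensor to numpy.
    #   Use Tensor.cpu() to copy the tensor to host memory first.
    arr = t.cpu().numpy()  # .cpu() is a no-op for CPU tensors, a copy for CUDA ones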
4 changes: 2 additions & 2 deletions ignite/metrics/gpu_info.py
@@ -10,7 +10,7 @@

class GpuInfo(Metric):
"""Provides GPU information: a) used memory percentage, b) gpu utilization percentage values as Metric
- on each iterations.
+ on each iterations. This metric requires `pynvml <https://pypi.org/project/pynvml/>`_ package of version `<12`.
.. Note ::
@@ -39,7 +39,7 @@ def __init__(self) -> None:
except ImportError:
raise ModuleNotFoundError(
"This contrib module requires pynvml to be installed. "
"Please install it with command: \n pip install pynvml"
"Please install it with command: \n pip install 'pynvml<12'"
)
# Let's check available devices
if not torch.cuda.is_available():
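Note: a usage sketch for the pinned dependency (editor's illustration, not part of the diff; assumes `pip install 'pynvml<12'` and a visible NVIDIA GPU):

    from ignite.engine import Engine
    from ignite.metrics import GpuInfo

    trainer = Engine(lambda e, b: None)

    # Adds 'gpu:X mem(%)' and 'gpu:X util(%)' entries to trainer.state.metrics
    # on each iteration; raises at construction if CUDA or pynvml is unavailable.
    GpuInfo().attach(trainer, name="gpu")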
4 changes: 2 additions & 2 deletions ignite/metrics/regression/kendall_correlation.py
@@ -16,8 +16,8 @@ def _get_kendall_tau(variant: str = "b") -> Callable[[Tensor, Tensor], float]:
raise ValueError(f"variant accepts 'b' or 'c', got {variant!r}.")

def _tau(predictions: Tensor, targets: Tensor) -> float:
-        np_preds = predictions.flatten().numpy()
-        np_targets = targets.flatten().numpy()
+        np_preds = predictions.flatten().cpu().numpy()
+        np_targets = targets.flatten().cpu().numpy()
r = kendalltau(np_preds, np_targets, variant=variant).statistic
return r

4 changes: 2 additions & 2 deletions ignite/metrics/regression/spearman_correlation.py
@@ -12,8 +12,8 @@
def _spearman_r(predictions: Tensor, targets: Tensor) -> float:
from scipy.stats import spearmanr

-    np_preds = predictions.flatten().numpy()
-    np_targets = targets.flatten().numpy()
+    np_preds = predictions.flatten().cpu().numpy()
+    np_targets = targets.flatten().cpu().numpy()
r = spearmanr(np_preds, np_targets).statistic
return r

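Note: both rank-correlation helpers above flatten tensors and move them to the CPU before calling SciPy. A standalone sketch of the same computation (editor's illustration with made-up values, not part of the diff; the `.statistic` attribute mirrors the source and needs a recent SciPy):

    import torch
    from scipy.stats import kendalltau, spearmanr

    preds = torch.tensor([0.1, 0.4, 0.3, 0.9])
    targets = torch.tensor([0.0, 0.5, 0.2, 1.0])

    # SciPy operates on NumPy arrays, so tensors must be on host memory first.
    np_preds = preds.flatten().cpu().numpy()
    np_targets = targets.flatten().cpu().numpy()

    print(kendalltau(np_preds, np_targets, variant="b").statistic)
    print(spearmanr(np_preds, np_targets).statistic)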
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -21,7 +21,7 @@ mlflow
neptune-client>=0.16.17
tensorboard
torchvision
- pynvml
+ pynvml<12 # pynvml module was removed in 12.X, is not developed or maintained. We should replace pynvml with something else.
clearml
scikit-image
py-rouge
5 changes: 2 additions & 3 deletions tests/common_test_functionality.sh
@@ -85,7 +85,6 @@ run_tests() {
skip_distrib_opt=""
fi


echo [pytest] > pytest.ini ; echo "cache_dir=${cache_dir}" >> pytest.ini

# Assemble options for the pytest command
@@ -103,8 +102,8 @@ run_tests() {

# Run the command
if [ "$trap_deselected_exit_code" -eq "1" ]; then
CUDA_VISIBLE_DEVICES="" eval "pytest ${pytest_args}" || { exit_code=$?; if [ "$exit_code" -eq ${last_failed_no_failures_code} ]; then echo "All tests deselected"; else exit $exit_code; fi; }
eval "pytest ${pytest_args}" || { exit_code=$?; if [ "$exit_code" -eq ${last_failed_no_failures_code} ]; then echo "All tests deselected"; else exit $exit_code; fi; }
else
CUDA_VISIBLE_DEVICES="" eval "pytest ${pytest_args}"
eval "pytest ${pytest_args}"
fi
}
1 change: 0 additions & 1 deletion tests/ignite/contrib/engines/test_common.py
@@ -8,7 +8,6 @@
from torch.utils.data.distributed import DistributedSampler

import ignite.distributed as idist

import ignite.handlers as handlers
from ignite.contrib.engines.common import (
_setup_logging,
Expand Down
45 changes: 30 additions & 15 deletions tests/ignite/engine/test_engine.py
@@ -40,11 +40,14 @@ class TestEngine:
def set_interrupt_resume_enabled(self, interrupt_resume_enabled):
Engine.interrupt_resume_enabled = interrupt_resume_enabled

-    def test_terminate(self):
+    @pytest.mark.parametrize("skip_completed", [True, False])
+    def test_terminate(self, skip_completed):
         engine = Engine(lambda e, b: 1)
         assert not engine.should_terminate
-        engine.terminate()
+        assert not engine.skip_completed_after_termination
+        engine.terminate(skip_completed)
         assert engine.should_terminate
+        assert engine.skip_completed_after_termination == skip_completed

def test_invalid_process_raises_with_invalid_signature(self):
with pytest.raises(ValueError, match=r"Engine must be given a processing function in order to run"):
@@ -236,25 +239,32 @@ def check_iter_and_data():
assert num_calls_check_iter_epoch == 1

@pytest.mark.parametrize(
"terminate_event, e, i",
"terminate_event, e, i, skip_completed",
[
-            (Events.STARTED, 0, 0),
-            (Events.EPOCH_STARTED(once=2), 2, None),
-            (Events.EPOCH_COMPLETED(once=2), 2, None),
-            (Events.GET_BATCH_STARTED(once=12), None, 12),
-            (Events.GET_BATCH_COMPLETED(once=12), None, 12),
-            (Events.ITERATION_STARTED(once=14), None, 14),
-            (Events.ITERATION_COMPLETED(once=14), None, 14),
+            (Events.STARTED, 0, 0, True),
+            (Events.EPOCH_STARTED(once=2), 2, None, True),
+            (Events.EPOCH_COMPLETED(once=2), 2, None, True),
+            (Events.GET_BATCH_STARTED(once=12), None, 12, True),
+            (Events.GET_BATCH_COMPLETED(once=12), None, 12, False),
+            (Events.ITERATION_STARTED(once=14), None, 14, True),
+            (Events.ITERATION_COMPLETED(once=14), None, 14, True),
+            (Events.STARTED, 0, 0, False),
+            (Events.EPOCH_STARTED(once=2), 2, None, False),
+            (Events.EPOCH_COMPLETED(once=2), 2, None, False),
+            (Events.GET_BATCH_STARTED(once=12), None, 12, False),
+            (Events.GET_BATCH_COMPLETED(once=12), None, 12, False),
+            (Events.ITERATION_STARTED(once=14), None, 14, False),
+            (Events.ITERATION_COMPLETED(once=14), None, 14, False),
],
)
-    def test_terminate_events_sequence(self, terminate_event, e, i):
+    def test_terminate_events_sequence(self, terminate_event, e, i, skip_completed):
engine = RecordedEngine(MagicMock(return_value=1))
data = range(10)
max_epochs = 5

@engine.on(terminate_event)
def call_terminate():
-            engine.terminate()
+            engine.terminate(skip_completed)

@engine.on(Events.EXCEPTION_RAISED)
def assert_no_exceptions(ee):
@@ -271,10 +281,15 @@ def assert_no_exceptions(ee):
if e is None:
e = i // len(data) + 1

+        if skip_completed:
+            assert engine.called_events[-1] == (e, i, Events.TERMINATE)
+            assert engine.called_events[-2] == (e, i, terminate_event)
+        else:
+            assert engine.called_events[-1] == (e, i, Events.COMPLETED)
+            assert engine.called_events[-2] == (e, i, Events.TERMINATE)
+            assert engine.called_events[-3] == (e, i, terminate_event)

         assert engine.called_events[0] == (0, 0, Events.STARTED)
-        assert engine.called_events[-1] == (e, i, Events.COMPLETED)
-        assert engine.called_events[-2] == (e, i, Events.TERMINATE)
-        assert engine.called_events[-3] == (e, i, terminate_event)
assert engine._dataloader_iter is None

@pytest.mark.parametrize("data, epoch_length", [(None, 10), (range(10), None)])
17 changes: 17 additions & 0 deletions tests/ignite/metrics/test_classification_report.py
@@ -164,6 +164,23 @@ def update(engine, i):
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
@pytest.mark.skipif(Version(torch.__version__) < Version("1.7.0"), reason="Skip if < 1.7.0")
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):

+    pytest.skip("Temporarily skip failing test. See https://github.com/pytorch/ignite/pull/3301")
+    # When run with 2 devices:
+    # tests/ignite/metrics/test_classification_report.py::test_distrib_nccl_gpu Fatal Python error: Aborted
+    # Thread 0x00007fac95c95700 (most recent call first):
+    # <no Python frame>
+
+    # Thread 0x00007facbb89b700 (most recent call first):
+    # <no Python frame>
+
+    # Thread 0x00007fae637f4700 (most recent call first):
+    # File "<string>", line 534 in read
+    # File "<string>", line 567 in from_io
+    # File "<string>", line 1160 in _thread_receiver
+    # File "<string>", line 341 in run
+    # File "<string>", line 411 in _perform_spawn

device = idist.device()
_test_integration_multiclass(device, True)
_test_integration_multiclass(device, False)
4 changes: 2 additions & 2 deletions tests/ignite/metrics/test_hsic.py
@@ -139,10 +139,10 @@ def test_integration(self, sigma_x: float, sigma_y: float):
metric_devices.append(device)

for metric_device in metric_devices:
-            x = torch.randn((n_iters * batch_size, n_dims_x)).float().to(device)
+            x = torch.randn((n_iters * batch_size, n_dims_x), device=device).float()

lin = nn.Linear(n_dims_x, n_dims_y).to(device)
-            y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y) * 1e-4
+            y = torch.sin(lin(x) * 100) + torch.randn(n_iters * batch_size, n_dims_y, device=x.device) * 1e-4

def data_loader(i, input_x, input_y):
return input_x[i * batch_size : (i + 1) * batch_size], input_y[i * batch_size : (i + 1) * batch_size]
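Note: the test fix above creates the noise term directly on the input's device. A minimal illustration of the device mismatch this avoids (editor's sketch, not part of the diff):

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"

    x = torch.randn(4, 3, device=device)
    lin = nn.Linear(3, 2).to(device)

    # torch.randn(...) without device= allocates on the CPU; combining it with
    # a CUDA tensor raises "Expected all tensors to be on the same device".
    y = torch.sin(lin(x) * 100) + torch.randn(4, 2, device=x.device) * 1e-4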
(2 more changed files not shown.)
