hooks_exponential_moving_average_model_hook_test.py
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
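
"""Tests for ExponentialMovingAverageModelHook: EMA tracking across train and
test phases, constructor validation, and model state iteration."""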
import copy
import math
import unittest
import unittest.mock as mock

import torch
import torch.nn as nn
from classy_vision.hooks import ExponentialMovingAverageModelHook
from classy_vision.models import ClassyModel
from test.generic.hook_test_utils import HookTestBase


class TestModel(ClassyModel):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)
        self.bn = nn.BatchNorm1d(10)

    def init_fc_weight(self):
        nn.init.zeros_(self.fc.weight)

    def update_fc_weight(self):
        nn.init.ones_(self.fc.weight)

    def forward(self, x):
        return self.bn(self.fc(x))

class TestExponentialMovingAverageModelHook(HookTestBase):
    def _map_device_string(self, device):
        return "cuda" if device == "gpu" else "cpu"

    def _test_exponential_moving_average_hook(self, model_device, hook_device):
        task = mock.MagicMock()
        model = TestModel().to(device=self._map_device_string(model_device))
        task.base_model = model
        task.train = True
        decay = 0.5
        num_updates = 10
        model.init_fc_weight()

        exponential_moving_average_hook = ExponentialMovingAverageModelHook(
            decay=decay, device=hook_device
        )
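        # on_start is expected to seed the hook's EMA state from the current
        # model weights (all zeros after init_fc_weight); on_phase_start marks
        # the beginning of a train phase, during which on_step updates the EMA.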
        exponential_moving_average_hook.on_start(task)
        exponential_moving_average_hook.on_phase_start(task)

        # set the weights to all ones and simulate 10 updates
        task.base_model.update_fc_weight()
        fc_weight = model.fc.weight.clone()
        for _ in range(num_updates):
            exponential_moving_average_hook.on_step(task)
        exponential_moving_average_hook.on_phase_end(task)

        # the model weights shouldn't have changed
        self.assertTrue(torch.allclose(model.fc.weight, fc_weight))

        # simulate a test phase now
        task.train = False
        exponential_moving_average_hook.on_phase_start(task)
        exponential_moving_average_hook.on_phase_end(task)

        # the model weights should be updated to the ema weights
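        # A sketch of why this is the expected value, assuming the hook's
        # update rule is ema <- (1 - decay) * ema + decay * param: the EMA
        # starts from the all-zero weights captured in on_start, and param is
        # fixed at 1 after update_fc_weight, so unrolling num_updates steps
        # gives ema_n = decay * sum_{k=0}^{n-1} (1 - decay)^k
        #            = 1 - (1 - decay)^num_updates.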
        self.assertTrue(
            torch.allclose(
                model.fc.weight, fc_weight * (1 - math.pow(1 - decay, num_updates))
            )
        )

        # simulate a train phase again
        task.train = True
        exponential_moving_average_hook.on_phase_start(task)

        # the model weights should be back to the old value
        self.assertTrue(torch.allclose(model.fc.weight, fc_weight))

    def test_constructors(self) -> None:
        """
        Test that the hooks are constructed correctly.
        """
        config = {"decay": 0.5, "consider_bn_buffers": True, "device": "cpu"}

        invalid_config1 = copy.deepcopy(config)
        del invalid_config1["decay"]
        invalid_config2 = copy.deepcopy(config)
        invalid_config2["device"] = "crazy_hardware"

        self.constructor_test_helper(
            config=config,
            hook_type=ExponentialMovingAverageModelHook,
            hook_registry_name="ema_model_weights",
            invalid_configs=[invalid_config1, invalid_config2],
        )

    def test_get_model_state_iterator(self):
        device = "gpu" if torch.cuda.is_available() else "cpu"
        model = TestModel().to(device=self._map_device_string(device))
        decay = 0.5

        # test that we pick up the right parameters in the iterator
        for consider_bn_buffers in [True, False]:
            exponential_moving_average_hook = ExponentialMovingAverageModelHook(
                decay=decay, consider_bn_buffers=consider_bn_buffers, device=device
            )
            iterable = exponential_moving_average_hook.get_model_state_iterator(model)
            fc_found = False
            bn_found = False
            bn_buffer_found = False
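            # compare with `is`, not `==`: tensor equality is elementwise and
            # would not check that the iterator yields these exact objects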
            for _, param in iterable:
                if any(param is item for item in model.fc.parameters()):
                    fc_found = True
                if any(param is item for item in model.bn.parameters()):
                    bn_found = True
                if any(param is item for item in model.bn.buffers()):
                    bn_buffer_found = True

            self.assertTrue(fc_found)
            self.assertTrue(bn_found)
            self.assertEqual(bn_buffer_found, consider_bn_buffers)

    def test_exponential_moving_average_hook(self):
        device = "gpu" if torch.cuda.is_available() else "cpu"
        self._test_exponential_moving_average_hook(device, device)

    @unittest.skipUnless(torch.cuda.is_available(), "This test needs a gpu to run")
    def test_mixed_devices(self):
        """Tests that the hook works when the model and the hook are on different devices"""
        self._test_exponential_moving_average_hook("cpu", "gpu")
        self._test_exponential_moving_average_hook("gpu", "cpu")
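

# Convenience entry point (an addition, not necessarily part of the upstream
# test setup): lets the file be run directly as `python <this file>`.
if __name__ == "__main__":
    unittest.main()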