@@ -68,6 +68,13 @@ def test_linear_scheduler_asserts():
     with pytest.raises(ValueError, match=r"Argument cycle_size should be positive and larger than 1"):
         LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=1)
 
+    with pytest.raises(
+        ValueError,
+        match=r"Invalid combination when warmup_duration > 0 and monotonic=False, "
+        r"please use either set warmup_duration=0 or monotonic=True",
+    ):
+        LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=2, warmup_duration=1)
+
 
 def test_linear_scheduler():
     tensor = torch.zeros([1], requires_grad=True)
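Note on the assertion added above: `warmup_duration` is only accepted together with `monotonic=True`. A plausible rationale (an inference, not stated in the diff): a triangular, non-monotonic cycle already ramps back up to the start value on its own, so an explicit warm-up leg only makes sense for a monotonic, decay-only cycle. A minimal sketch of the two constructor calls (the import path and the accepted combination are assumptions; the rejected call is exactly what the new test asserts):

    import torch
    from ignite.handlers import LinearCyclicalScheduler

    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.0)

    # Accepted (assumption based on the error message): warm-up paired
    # with a monotonic, decay-only cycle.
    LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=2, warmup_duration=1, monotonic=True)

    # Rejected, as the new test asserts: warm-up on a triangular
    # (monotonic=False) cycle raises ValueError.
    LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=2, warmup_duration=1)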
@@ -144,6 +151,102 @@ def save_lr(engine):
     scheduler.load_state_dict(state_dict)
 
 
+def test_linear_scheduler_warmup_duration():
+    tensor = torch.zeros([1], requires_grad=True)
+    optimizer = torch.optim.SGD([tensor], lr=0.0)
+
+    scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, warmup_duration=5, monotonic=True)
+    state_dict = scheduler.state_dict()
+
+    def save_lr(engine):
+        lrs.append(optimizer.param_groups[0]["lr"])
+
+    trainer = Engine(lambda engine, batch: None)
+    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
+    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
+    lr_values_in_cycle = [
+        1.0,
+        0.9,
+        0.8,
+        0.7,
+        0.6,
+        0.5,
+        0.4,
+        0.3,
+        0.2,
+        0.1,
+        0.0,
+        0.2,
+        0.4,
+        0.6,
+        0.8,
+        1.0,
+        0.9,
+        0.8,
+        0.7,
+        0.6,
+    ]
+    for _ in range(2):
+        lrs = []
+        trainer.run([0] * 10, max_epochs=2)
+
+        assert lrs == pytest.approx(lr_values_in_cycle)
+        scheduler.load_state_dict(state_dict)
+
+    optimizer = torch.optim.SGD([tensor], lr=0)
+    scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, cycle_mult=2, warmup_duration=5, monotonic=True)
+    state_dict = scheduler.state_dict()
+
+    trainer = Engine(lambda engine, batch: None)
+    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
+    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
+
+    for _ in range(2):
+        lrs = []
+        trainer.run([0] * 10, max_epochs=3)
+
+        assert lrs == list(
+            map(
+                pytest.approx,
+                [
+                    # Cycle 1
+                    1.0,
+                    0.9,
+                    0.8,
+                    0.7,
+                    0.6,
+                    0.5,
+                    0.4,
+                    0.3,
+                    0.2,
+                    0.1,
+                    0.0,
+                    0.2,
+                    0.4,
+                    0.6,
+                    0.8,
+                    # Cycle 2
+                    1.0,
+                    0.95,
+                    0.9,
+                    0.85,
+                    0.8,
+                    0.75,
+                    0.7,
+                    0.65,
+                    0.6,
+                    0.55,
+                    0.5,
+                    0.45,
+                    0.4,
+                    0.35,
+                    0.3,
+                ],
+            )
+        )
+        scheduler.load_state_dict(state_dict)
+
+
 def test_linear_scheduler_cycle_size_two():
     tensor = torch.zeros([1], requires_grad=True)
     optimizer = torch.optim.SGD([tensor], lr=0)
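For reference, the expected lists asserted in test_linear_scheduler_warmup_duration follow a simple saw-tooth: the value decays linearly from start_value to end_value over cycle_size events, ramps linearly back up over warmup_duration events, and cycle_size is multiplied by cycle_mult at each cycle boundary. Below is a standalone sketch of that shape (a hypothetical helper, not ignite's implementation; simulate_schedule is an invented name):

    import pytest

    def simulate_schedule(start_value, end_value, cycle_size, warmup_duration, cycle_mult=1, n_events=20):
        values = []
        event, size = 0, cycle_size  # position within the current cycle, current cycle length
        for _ in range(n_events):
            if event <= size:
                # Decay leg: start_value -> end_value across `size` events.
                values.append(start_value - (start_value - end_value) * event / size)
            else:
                # Warm-up leg: end_value -> start_value across `warmup_duration` events.
                values.append(end_value + (start_value - end_value) * (event - size) / warmup_duration)
            event += 1
            if event == size + warmup_duration:
                event, size = 0, size * cycle_mult  # start a new, possibly longer cycle
        return values

    # Reproduces the first expected list (2 epochs x 10 iterations = 20 events):
    assert simulate_schedule(1.0, 0.0, 10, 5, n_events=20) == pytest.approx(
        [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.9, 0.8, 0.7, 0.6]
    )

With cycle_mult=2 and n_events=30 the same helper reproduces the second list: the decay step halves from 0.1 to 0.05 because the second cycle is twice as long.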