 torch.Tensor.ndim = property(lambda x: len(x.shape))


-def listify(o):
+def listify(o=None) -> list:
+    """
+    Convert `o` to a list. If `o` is None, return an empty list.
+    """
     if o is None:
         return []
     if isinstance(o, list):
@@ -24,34 +27,53 @@ def listify(o):
     return [o]
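
# A minimal usage sketch of `listify` (behavior follows from the branches
# above; the str/Iterable branches are elided by this diff but assumed to
# return [o] and list(o) as in the fastai course version):
#
#   listify(None)    # -> []
#   listify([1, 2])  # -> [1, 2]
#   listify(3)       # -> [3]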


-def annealer(f):
+def annealer(f) -> callable:
+    """
+    A decorator for creating partially applied functions with predefined
+    `start` and `end` arguments: the inner function `_inner` captures
+    `start` and `end` and returns a `partial` object that fixes these
+    parameters for the decorated function `f`.
+    """
     def _inner(start, end):
         return partial(f, start, end)

     return _inner
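
# A sketch of a decorated schedule in use (assuming `partial` is
# `functools.partial`, imported elsewhere in this file, and `sched_lin`
# as defined below):
#
#   f = sched_lin(1e-2, 1e-3)  # partial fixing start=1e-2, end=1e-3
#   f(0.0)   # -> 0.01
#   f(0.5)   # -> 0.0055
#   f(1.0)   # -> 0.001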


 @annealer
-def sched_lin(start, end, pos):
+def sched_lin(start: float, end: float, pos: float) -> float:
+    """
+    A linear schedule function.
+    """
     return start + pos * (end - start)


 @annealer
-def sched_cos(start, end, pos):
+def sched_cos(start: float, end: float, pos: float) -> float:
+    """
+    A cosine schedule function.
+    """
     return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2


 @annealer
-def sched_no(start, end, pos):
+def sched_no(start: float, end: float, pos: float) -> float:
+    """
+    Disabled scheduling: always returns `start`.
+    """
     return start


 @annealer
-def sched_exp(start, end, pos):
+def sched_exp(start: float, end: float, pos: float) -> float:
+    """
+    Exponential schedule function.
+    """
     return start * (end / start) ** pos
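
# A quick comparison of the four schedules at pos = 0.5, with start=1.0 and
# end=0.1 (values follow directly from the formulas above; `sched_cos` hits
# the midpoint because cos(pi/2) = 0):
#
#   sched_lin(1.0, 0.1)(0.5)  # -> 0.55
#   sched_cos(1.0, 0.1)(0.5)  # -> 0.55
#   sched_no(1.0, 0.1)(0.5)   # -> 1.0
#   sched_exp(1.0, 0.1)(0.5)  # -> ~0.316 (1.0 * 0.1 ** 0.5)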


-def combine_scheds(pcts, scheds):
+def combine_scheds(pcts: Iterable[float], scheds: Iterable[callable]) -> callable:
+    """
+    Combine multiple scheduling functions, each active for the corresponding
+    fraction of training given in `pcts` (which must sum to 1).
+    """
     assert sum(pcts) == 1.0
     pcts = torch.tensor([0] + listify(pcts))
     assert torch.all(pcts >= 0)
@@ -65,20 +87,30 @@ def _inner(pos):
     return _inner
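
# A sketch of a one-cycle-style schedule built from the pieces above,
# mirroring the cos_1cycle_anneal helper this commit removes (the 0.3/0.7
# split and the lr values are illustrative, not from the commit):
#
#   sched = combine_scheds([0.3, 0.7],
#                          [sched_cos(0.3, 0.6), sched_cos(0.6, 0.2)])
#   sched(0.1)  # first 30% of training: rising cosine from 0.3 to 0.6
#   sched(0.5)  # remaining 70%: falling cosine from 0.6 to 0.2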


-# def cos_1cycle_anneal(start, high, end):
-#     return [sched_cos(start, high), sched_cos(high, end)]
-
+def camel2snake(name: str) -> str:
+    """
+    Convert `name` from camel case to snake case.
+    """
+    s1 = re.sub(_camel_re1, r"\1_\2", name)
+    return re.sub(_camel_re2, r"\1_\2", s1).lower()
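
# A sketch of the conversion, assuming the usual regex definitions from the
# fastai course (they are referenced but not shown in this diff):
#
#   _camel_re1 = re.compile("(.)([A-Z][a-z]+)")
#   _camel_re2 = re.compile("([a-z0-9])([A-Z])")
#
#   camel2snake("TrainEvalCallback")  # -> "train_eval_callback"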


 class Callback:
+    """
+    Base class for callbacks hooked into an event-driven training process.
+    It can set a runner, expose the class name in snake_case, call callback
+    methods by name via `__call__`, and delegate attribute access to the
+    runner when the attribute does not exist on the Callback itself.
+
+    `_order` decides the order in which Callbacks run.
+    """
     _order = 0

-    def set_runner(self, run):
+    def set_runner(self, run) -> None:
         self.run = run

-    def __getattr__(self, k):
-        return getattr(self.run, k)
-
     @property
     def name(self):
         name = re.sub(r"Callback$", "", self.__class__.__name__)
@@ -90,59 +122,121 @@ def __call__(self, cb_name):
             return True
         return False

+    def __getattr__(self, k):
+        return getattr(self.run, k)
+
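
# A sketch of a minimal subclass, assuming the runner protocol used in the
# fastai course notebooks (the runner looks methods up by name and invokes
# them through `__call__`):
#
#   class PrintEpochCallback(Callback):
#       def begin_epoch(self):
#           print(f"epoch {self.epoch}")  # `epoch` resolves via __getattr__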


 class ParamScheduler(Callback):
+    """
+    Manages scheduling of parameter adjustments over the course of training.
+    """
     _order = 1

     def __init__(self, pname, sched_funcs):
         self.pname, self.sched_funcs = pname, sched_funcs

     def begin_fit(self):
+        """
+        Prepare the scheduler at the start of the fitting process, ensuring
+        that `sched_funcs` is a list with one function per parameter group.
+        """
         if not isinstance(self.sched_funcs, (list, tuple)):
             self.sched_funcs = [self.sched_funcs] * len(self.opt.param_groups)

     def set_param(self):
+        """
+        Adjust the parameter value of each parameter group using its
+        scheduling function; the number of scheduling functions must match
+        the number of parameter groups.
+        """
         assert len(self.opt.param_groups) == len(self.sched_funcs)
         for pg, f in zip(self.opt.param_groups, self.sched_funcs):
             pg[self.pname] = f(self.n_epochs / self.epochs)

     def begin_batch(self):
+        """
+        Apply parameter adjustments at the beginning of each batch if in
+        training mode.
+        """
         if self.in_train:
             self.set_param()
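
# A sketch of wiring the scheduler into a run; `Runner` and its `cb_funcs`
# argument are the course's classes, assumed here rather than shown in this
# commit:
#
#   sched = combine_scheds([0.3, 0.7],
#                          [sched_cos(0.3, 0.6), sched_cos(0.6, 0.2)])
#   run = Runner(cb_funcs=[partial(ParamScheduler, "lr", sched)])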

-
 class Recorder(Callback):
+    """
+    Recorder is a callback used to record learning rates and losses during
+    the training process.
+    """
     def begin_fit(self):
+        """
+        Initializes storage for learning rates and losses.
+
+        Attributes:
+            self.lrs (list): A list of lists, one per parameter group, each
+                holding that group's learning rates.
+            self.losses (list): An empty list to store loss values during
+                the fitting process.
+        """
         self.lrs = [[] for _ in self.opt.param_groups]
         self.losses = []

     def after_batch(self):
+        """
+        After each training batch, records the current learning rate of each
+        parameter group and appends the detached loss to the loss list.
+        """
         if not self.in_train:
             return
         for pg, lr in zip(self.opt.param_groups, self.lrs):
             lr.append(pg["lr"])
         self.losses.append(self.loss.detach().cpu())

     def plot_lr(self, pgid=-1):
+        """
+        Plots the learning rate history of a given parameter group.
+        """
         plt.plot(self.lrs[pgid])

     def plot_loss(self, skip_last=0):
+        """
+        Plots the recorded losses, optionally skipping the last `skip_last` values.
+        """
         plt.plot(self.losses[: len(self.losses) - skip_last])

     def plot(self, skip_last=0, pgid=-1):
+        """
+        Plots loss values against learning rates on a logarithmic x-axis.
+        """
         losses = [o.item() for o in self.losses]
         lrs = self.lrs[pgid]
         n = len(losses) - skip_last
         plt.xscale("log")
         plt.plot(lrs[:n], losses[:n])
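
# A sketch of reading results back after a run; `run.fit` and the
# name-based `run.recorder` attribute follow the course's Runner, which is
# assumed rather than shown in this commit:
#
#   run.fit(2, learn)
#   run.recorder.plot_lr()           # lr curve for the last parameter group
#   run.recorder.plot(skip_last=5)   # loss vs. lr, dropping the noisy tail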

-def camel2snake(name):
-    s1 = re.sub(_camel_re1, r"\1_\2", name)
-    return re.sub(_camel_re2, r"\1_\2", s1).lower()

+class TrainEvalCallback(Callback):
+    """
+    TrainEvalCallback is a callback used during the training and validation
+    phases of a model to perform specific actions at the beginning of, and
+    after, certain events.
+
+    Methods:
-class TrainEvalCallback(Callback):
+        begin_fit():
+            Initialize the number of epochs and iteration counts at the
+            start of the fitting process.
+
+        after_batch():
+            Update the epoch and iteration counts after each batch during
+            training.
+
+        begin_epoch():
+            Set the current epoch, switch the model to training mode, and
+            indicate that the model is in training.
+
+        begin_validate():
+            Switch the model to evaluation mode and indicate that the model
+            is in validation.
+    """
     def begin_fit(self):
         self.run.n_epochs = 0
         self.run.n_iter = 0
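
# A sketch of the counters this callback maintains (the after_batch body is
# elided by this diff; in the course version it advances `n_epochs` by a
# fraction of an epoch per batch and `n_iter` by one):
#
#   run.n_epochs  # fractional progress through the epochs
#   run.n_iter    # number of training batches seen so far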
@@ -176,6 +270,32 @@ class CancelBatchException(Exception):

 class AvgStats:
+    """
+    AvgStats tracks and accumulates average statistics (loss and other
+    metrics) during the training and validation phases.
+
+    Attributes:
+        metrics (list): A list of metric functions to be tracked.
+        in_train (bool): Whether the statistics are for the training phase.
+
+    Methods:
+        __init__(metrics, in_train):
+            Initializes the AvgStats with metrics and the in_train flag.
+
+        reset():
+            Resets the accumulated statistics.
+
+        all_stats:
+            Property returning all accumulated statistics, loss included.
+
+        avg_stats:
+            Property returning the average of the accumulated statistics.
+
+        accumulate(run):
+            Accumulates the statistics using data from the given run.
+
+        __repr__():
+            Returns a string representation of the average statistics.
+    """
     def __init__(self, metrics, in_train):
         self.metrics, self.in_train = listify(metrics), in_train
181
301
@@ -191,20 +311,32 @@ def all_stats(self):
     def avg_stats(self):
         return [o / self.count for o in self.all_stats]

-    def __repr__(self):
-        if not self.count:
-            return ""
-        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"
-
     def accumulate(self, run):
         bn = run.xb.shape[0]
         self.tot_loss += run.loss * bn
         self.count += bn
         for i, m in enumerate(self.metrics):
             self.tot_mets[i] += m(run.pred, run.yb) * bn

+    def __repr__(self):
+        if not self.count:
+            return ""
+        return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"
+


 class AvgStatsCallBack(Callback):
+    """
+    AvgStatsCallBack tracks and prints average statistics for the training
+    and validation phases during the training loop.
+
+    Arguments:
+        metrics: A list of metric functions to evaluate during training
+            and validation.
+
+    Methods:
+        __init__: Initializes the callback with the given metrics and sets
+            up AvgStats objects for the training and validation phases.
+        begin_epoch: Resets the statistics at the beginning of each epoch.
+        after_loss: Accumulates the metrics after computing the loss,
+            differentiating between training and validation phases.
+        after_epoch: Prints the accumulated statistics for both phases
+            after each epoch.
+    """
     def __init__(self, metrics):
         self.train_stats, self.valid_stats = AvgStats(metrics, True), AvgStats(
             metrics, False
342
metrics , False
0 commit comments