Skip to content

Commit 433958f

Browse files
refactored to make it easier to modify and add new things, updated documentation, squished some bugs
1 parent ec28a45 commit 433958f

File tree

4 files changed

+509
-392
lines changed

4 files changed

+509
-392
lines changed

openml_pytorch/callbacks.py

Lines changed: 156 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
torch.Tensor.ndim = property(lambda x: len(x.shape))
1313

1414

15-
def listify(o):
15+
def listify(o = None) -> list:
16+
"""
17+
Convert `o` to list. If `o` is None, return empty list.
18+
"""
1619
if o is None:
1720
return []
1821
if isinstance(o, list):
@@ -24,34 +27,53 @@ def listify(o):
2427
return [o]
2528

2629

27-
def annealer(f):
30+
def annealer(f) -> callable:
31+
"""
32+
A decorator function for creating a partially applied function with predefined start and end arguments.
33+
The inner function `_inner` captures the `start` and `end` parameters and returns a `partial` object that fixes these parameters for the decorated function `f`.
34+
"""
2835
def _inner(start, end):
2936
return partial(f, start, end)
3037

3138
return _inner
3239

3340

3441
@annealer
35-
def sched_lin(start, end, pos):
42+
def sched_lin(start: float, end: float, pos: float) -> float:
43+
"""
44+
Linear schedule: returns `start` at `pos=0` and `end` at `pos=1`, interpolating linearly in between.
45+
"""
3646
return start + pos * (end - start)
3747

3848

3949
@annealer
40-
def sched_cos(start, end, pos):
50+
def sched_cos(start: float, end: float, pos: float) -> float:
51+
"""
52+
Cosine schedule: returns `start` at `pos=0` and `end` at `pos=1`, easing between them along a half-cosine curve.
53+
"""
4154
return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2
4255

4356

4457
@annealer
45-
def sched_no(start, end, pos):
58+
def sched_no(start: float, end: float, pos: float) -> float:
59+
"""
60+
Constant schedule (no annealing): always returns `start`, ignoring `end` and `pos`.
61+
"""
4662
return start
4763

4864

4965
@annealer
50-
def sched_exp(start, end, pos):
66+
def sched_exp(start: float, end: float, pos: float) -> float:
67+
"""
68+
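Exponential schedule: returns `start * (end / start) ** pos`, i.e. `start` at `pos=0` and `end` at `pos=1`. `start` must be non-zero.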
69+
"""
5170
return start * (end / start) ** pos
5271

5372

54-
def combine_scheds(pcts, scheds):
73+
def combine_scheds(pcts: Iterable[float], scheds: Iterable[callable]) -> callable:
74+
"""
75+
Combine multiple scheduling functions into a single schedule. The values in `pcts` must be non-negative and sum to 1.
76+
"""
5577
assert sum(pcts) == 1.0
5678
pcts = torch.tensor([0] + listify(pcts))
5779
assert torch.all(pcts >= 0)
@@ -65,20 +87,30 @@ def _inner(pos):
6587
return _inner
6688

6789

68-
# def cos_1cycle_anneal(start, high, end):
69-
# return [sched_cos(start, high), sched_cos(high, end)]
70-
90+
def camel2snake(name : str) -> str:
91+
"""
92+
Convert `name` from camel case to snake case.
93+
"""
94+
s1 = re.sub(_camel_re1, r"\1_\2", name)
95+
return re.sub(_camel_re2, r"\1_\2", s1).lower()
7196

7297

7398
class Callback:
99+
"""
100+
101+
Callback class is a base class designed for handling different callback functions during
102+
an event-driven process. It provides functionality to set a runner, retrieve the class
103+
name in snake_case format, directly call callback methods, and delegate attribute access
104+
to the runner if the attribute does not exist in the Callback class.
105+
106+
The _order is used to decide the order of Callbacks.
107+
108+
"""
74109
_order = 0
75110

76-
def set_runner(self, run):
111+
def set_runner(self, run) -> None:
77112
self.run = run
78113

79-
def __getattr__(self, k):
80-
return getattr(self.run, k)
81-
82114
@property
83115
def name(self):
84116
name = re.sub(r"Callback$", "", self.__class__.__name__)
@@ -90,59 +122,121 @@ def __call__(self, cb_name):
90122
return True
91123
return False
92124

125+
def __getattr__(self, k):
126+
return getattr(self.run, k)
127+
93128

94129
class ParamScheduler(Callback):
130+
"""
131+
Manages scheduling of parameter adjustments over the course of training.
132+
"""
95133
_order = 1
96134

97135
def __init__(self, pname, sched_funcs):
98136
self.pname, self.sched_funcs = pname, sched_funcs
99137

100138
def begin_fit(self):
139+
"""
140+
Prepare the scheduler at the start of the fitting process.
141+
This method ensures that sched_funcs is a list with one function per parameter group.
142+
"""
101143
if not isinstance(self.sched_funcs, (list, tuple)):
102144
self.sched_funcs = [self.sched_funcs] * len(self.opt.param_groups)
103145

104146
def set_param(self):
147+
"""
148+
Adjust the parameter value for each parameter group based on the scheduling function.
149+
Ensures the number of scheduling functions matches the number of parameter groups.
150+
"""
105151
assert len(self.opt.param_groups) == len(self.sched_funcs)
106152
for pg, f in zip(self.opt.param_groups, self.sched_funcs):
107153
pg[self.pname] = f(self.n_epochs / self.epochs)
108154

109155
def begin_batch(self):
156+
"""
157+
Apply parameter adjustments at the beginning of each batch if in training mode.
158+
"""
110159
if self.in_train:
111160
self.set_param()
112161

113-
114162
class Recorder(Callback):
163+
"""
164+
Recorder is a callback class used to record learning rates and losses during the training process.
165+
"""
115166
def begin_fit(self):
167+
"""
168+
Initializes attributes necessary for the fitting process.
169+
170+
Sets up learning rates and losses storage.
171+
172+
Attributes:
173+
self.lrs (list): A list of lists, where each inner list will hold learning rates for a parameter group.
174+
self.losses (list): An empty list to store loss values during the fitting process.
175+
"""
116176
self.lrs = [[] for _ in self.opt.param_groups]
117177
self.losses = []
118178

119179
def after_batch(self):
180+
"""
181+
Handles operations to execute after each training batch.
182+
183+
Records (does not modify) the current learning rate of each parameter group in the optimizer
184+
and appends the detached loss, moved to the CPU, to the loss history. Does nothing outside of training.
185+
186+
"""
120187
if not self.in_train:
121188
return
122189
for pg, lr in zip(self.opt.param_groups, self.lrs):
123190
lr.append(pg["lr"])
124191
self.losses.append(self.loss.detach().cpu())
125192

126193
def plot_lr(self, pgid=-1):
194+
"""
195+
Plots the learning rate for a given parameter group.
196+
"""
127197
plt.plot(self.lrs[pgid])
128198

129199
def plot_loss(self, skip_last=0):
200+
"""
201+
Plots the recorded losses, omitting the last `skip_last` values.
202+
"""
130203
plt.plot(self.losses[: len(self.losses) - skip_last])
131204

132205
def plot(self, skip_last=0, pgid=-1):
206+
"""
207+
Plots the recorded losses against the learning rates of parameter group `pgid` on a log-scaled x-axis, omitting the last `skip_last` points.
208+
"""
133209
losses = [o.item() for o in self.losses]
134210
lrs = self.lrs[pgid]
135211
n = len(losses) - skip_last
136212
plt.xscale("log")
137213
plt.plot(lrs[:n], losses[:n])
138214

139215

140-
def camel2snake(name):
141-
s1 = re.sub(_camel_re1, r"\1_\2", name)
142-
return re.sub(_camel_re2, r"\1_\2", s1).lower()
216+
class TrainEvalCallback(Callback):
217+
"""
218+
TrainEvalCallback class is a custom callback used during the training
219+
and validation phases of a machine learning model to perform specific
220+
actions at the beginning and after certain events.
143221
222+
Methods:
144223
145-
class TrainEvalCallback(Callback):
224+
begin_fit():
225+
Initialize the number of epochs and iteration counts at the start
226+
of the fitting process.
227+
228+
after_batch():
229+
Update the epoch and iteration counts after each batch during
230+
training.
231+
232+
begin_epoch():
233+
Set the current epoch, switch the model to training mode, and
234+
indicate that the model is in training.
235+
236+
begin_validate():
237+
Switch the model to evaluation mode and indicate that the model
238+
is in validation.
239+
"""
146240
def begin_fit(self):
147241
self.run.n_epochs = 0
148242
self.run.n_iter = 0
@@ -176,6 +270,32 @@ class CancelBatchException(Exception):
176270

177271

178272
class AvgStats:
273+
"""
274+
AvgStats class is used to track and accumulate average statistics (like loss and other metrics) during training and validation phases.
275+
276+
Attributes:
277+
metrics (list): A list of metric functions to be tracked.
278+
in_train (bool): A flag to indicate if the statistics are for the training phase.
279+
280+
Methods:
281+
__init__(metrics, in_train):
282+
Initializes the AvgStats with metrics and in_train flag.
283+
284+
reset():
285+
Resets the accumulated statistics.
286+
287+
all_stats:
288+
Property that returns all accumulated statistics including loss and metrics.
289+
290+
avg_stats:
291+
Property that returns the average of the accumulated statistics.
292+
293+
accumulate(run):
294+
Accumulates the statistics using the data from the given run.
295+
296+
__repr__():
297+
Returns a string representation of the average statistics.
298+
"""
179299
def __init__(self, metrics, in_train):
180300
self.metrics, self.in_train = listify(metrics), in_train
181301

@@ -191,20 +311,32 @@ def all_stats(self):
191311
def avg_stats(self):
192312
return [o / self.count for o in self.all_stats]
193313

194-
def __repr__(self):
195-
if not self.count:
196-
return ""
197-
return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"
198-
199314
def accumulate(self, run):
200315
bn = run.xb.shape[0]
201316
self.tot_loss += run.loss * bn
202317
self.count += bn
203318
for i, m in enumerate(self.metrics):
204319
self.tot_mets[i] += m(run.pred, run.yb) * bn
205320

321+
def __repr__(self):
322+
if not self.count:
323+
return ""
324+
return f"{'train' if self.in_train else 'valid'}: {self.avg_stats}"
325+
206326

207327
class AvgStatsCallBack(Callback):
328+
"""
329+
AvgStatsCallBack class is a custom callback used to track and print average statistics for training and validation phases during the training loop.
330+
331+
Arguments:
332+
metrics: A list of metric functions to evaluate during training and validation.
333+
334+
Methods:
335+
__init__: Initializes the callback with given metrics and sets up AvgStats objects for both training and validation phases.
336+
begin_epoch: Resets the statistics at the beginning of each epoch.
337+
after_loss: Accumulates the metrics after computing the loss, differentiating between training and validation phases.
338+
after_epoch: Prints the accumulated statistics for both training and validation phases after each epoch.
339+
"""
208340
def __init__(self, metrics):
209341
self.train_stats, self.valid_stats = AvgStats(metrics, True), AvgStats(
210342
metrics, False

0 commit comments

Comments
 (0)