# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import pytorch_lightning as pl
import torch

from gluonts.core.component import validated
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood

from .module import TSMixerModel


class TSMixerLightningModule(pl.LightningModule):
"""
A ``pl.LightningModule`` class that can be used to train a
``TSMixerModel`` with PyTorch Lightning.
This is a thin layer around a (wrapped) ``TSMixerModel`` object,
that exposes the methods to evaluate training and validation loss.
Parameters
----------
model
``TSMixerModel`` to be trained.
loss
Loss function to be used for training,
default: ``NegativeLogLikelihood()``.
lr
Learning rate, default: ``1e-3``.
weight_decay
Weight decay regularization parameter, default: ``1e-8``.
"""
@validated()
def __init__(
self,
model_kwargs: dict,
loss: DistributionLoss = NegativeLogLikelihood(),
lr: float = 1e-3,
weight_decay: float = 1e-8,
):
super().__init__()
self.save_hyperparameters()
self.model = TSMixerModel(**model_kwargs)
self.loss = loss
self.lr = lr
self.weight_decay = weight_decay
def forward(self, *args, **kwargs):
distr_args, loc, scale = self.model.forward(*args, **kwargs)
distr = self.model.distr_output.distribution(distr_args, loc, scale)
return distr.sample((self.model.num_parallel_samples,)).reshape(
-1,
self.model.num_parallel_samples,
self.model.prediction_length,
self.model.input_size,
)
def _compute_loss(self, batch):
past_target = batch["past_target"]
past_observed_values = batch["past_observed_values"]
target = batch["future_target"]
observed_target = batch["future_observed_values"]
assert past_target.shape[1] == self.model.context_length
assert target.shape[1] == self.model.prediction_length
distr_args, loc, scale = self.model(
past_target=past_target,
past_observed_values=past_observed_values,
past_time_feat=batch["past_time_feat"],
future_time_feat=batch["future_time_feat"],
)
distr = self.model.distr_output.distribution(distr_args, loc, scale)
return (self.loss(distr, target) * observed_target).sum() / torch.maximum(
torch.tensor(1.0), observed_target.sum()
)
def training_step(self, batch, batch_idx: int): # type: ignore
"""
Execute training step.
"""
train_loss = self._compute_loss(batch)
self.log(
"train_loss",
train_loss,
on_epoch=True,
on_step=False,
prog_bar=True,
)
return train_loss
def validation_step(self, batch, batch_idx: int): # type: ignore
"""
Execute validation step.
"""
val_loss = self._compute_loss(batch)
self.log("val_loss", val_loss, on_epoch=True, on_step=False, prog_bar=True)
return val_loss
def configure_optimizers(self):
"""
Returns the optimizer to use.
"""
return torch.optim.Adam(
self.model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay,
)
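
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): one possible
# way to train this LightningModule with a PyTorch Lightning ``Trainer``.
# The ``model_kwargs`` contents and the dataloader names below are
# hypothetical placeholders; the actual constructor arguments are defined by
# ``TSMixerModel`` in ``module.py``, and the dataloaders would normally be
# built by the accompanying estimator.
# ---------------------------------------------------------------------------
#
# module = TSMixerLightningModule(
#     model_kwargs={...},  # TSMixerModel constructor arguments (hypothetical)
#     lr=1e-3,
#     weight_decay=1e-8,
# )
# trainer = pl.Trainer(max_epochs=10)
# trainer.fit(
#     module,
#     train_dataloaders=train_loader,  # hypothetical dataloader names
#     val_dataloaders=val_loader,
# )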