-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathestimator.py
332 lines (306 loc) · 11.7 KB
/
estimator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from typing import List, Optional, Iterable, Dict, Any
import torch
import pytorch_lightning as pl
from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.loader import as_stacked_batches
from gluonts.dataset.stat import calculate_dataset_statistics
from gluonts.itertools import Cyclic
from gluonts.torch.modules.loss import DistributionLoss, NegativeLogLikelihood
from gluonts.transform import (
Transformation,
AddObservedValuesIndicator,
InstanceSampler,
InstanceSplitter,
ValidationSplitSampler,
TestSplitSampler,
ExpectedNumInstanceSampler,
RemoveFields,
SetField,
AddTimeFeatures,
AddAgeFeature,
VstackFeatures,
)
from gluonts.time_feature import TimeFeature, time_features_from_frequency_str
from gluonts.torch.model.estimator import PyTorchLightningEstimator
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.torch.distributions import DistributionOutput, StudentTOutput
from .lightning_module import TSMixerLightningModule
# Names of the batch fields handed to the network by the predictor
# (see ``input_names`` in ``create_predictor``).
PREDICTION_INPUT_NAMES = [
    "feat_static_cat",
    "feat_static_real",
    "past_time_feat",
    "past_target",
    "past_observed_values",
    "future_time_feat",
]

# Training and validation batches carry everything the predictor sees,
# plus the future target window and its observed-values mask.
TRAINING_INPUT_NAMES = [
    *PREDICTION_INPUT_NAMES,
    "future_target",
    "future_observed_values",
]
class TSMixerEstimator(PyTorchLightningEstimator):
    """
    An estimator training a TSMixer model for forecasting.

    This class uses the model defined in ``TSMixerModel``,
    and wraps it into a ``TSMixerLightningModule`` for training
    purposes: training is performed using PyTorch Lightning's ``pl.Trainer``
    class.

    Parameters
    ----------
    freq
        Frequency string of the data; used to derive the default
        ``time_features`` when none are given.
    prediction_length
        Length of the prediction horizon.
    context_length
        Number of time steps prior to prediction time that the model
        takes as inputs (default: ``10 * prediction_length``).
    input_size
        Forwarded to the model as ``input_size`` (default: 1).
    depth
        Forwarded to the model as ``depth`` (default: 1).
    dim
        Forwarded to the model as ``dim`` (default: 32).
    expansion_factor
        Forwarded to the model as ``expansion_factor`` (default: 4).
    dropout
        Dropout rate forwarded to the model (default: 0.1).
    batch_norm
        Whether to apply batch normalization (default: ``True``).
    scaling
        Scaling scheme forwarded to the model (default: ``"mean"``).
    num_feat_dynamic_real
        Number of dynamic real-valued features (default: 0). They are
        stacked together with the time and age features.
    num_feat_static_cat
        Number of static categorical features (default: 0). When 0, a
        dummy ``feat_static_cat`` equal to ``[0]`` is inserted.
    num_feat_static_real
        Number of static real-valued features (default: 0). When 0, a
        dummy ``feat_static_real`` equal to ``[0.0]`` is inserted.
    cardinality
        Number of distinct values of each static categorical feature;
        replaced by ``[1]`` if not given or if ``num_feat_static_cat`` is 0.
    embedding_dimension
        Presumably the embedding size per static categorical feature.
        Currently only stored on the estimator, not forwarded to the model.
    time_features
        Time features computed from the timestamps (default: derived from
        ``freq`` via ``time_features_from_frequency_str``).
    lr
        Learning rate (default: ``1e-3``).
    weight_decay
        Weight decay regularization parameter (default: ``1e-8``).
    distr_output
        Distribution to use to evaluate observations and sample predictions
        (default: ``StudentTOutput()``).
    loss
        Loss to be optimized during training
        (default: ``NegativeLogLikelihood()``).
    num_parallel_samples
        Forwarded to the model (default: 100); presumably the number of
        sample paths drawn per series at prediction time.
    batch_size
        The size of the batches used for training, validation and
        prediction (default: 32).
    num_batches_per_epoch
        Number of batches to be processed in each training epoch
        (default: 50).
    trainer_kwargs
        Additional arguments to provide to ``pl.Trainer`` for construction;
        merged over the default ``{"max_epochs": 100}``.
    train_sampler
        Controls the sampling of windows during training (default:
        ``ExpectedNumInstanceSampler(num_instances=1.0,
        min_future=prediction_length)``).
    validation_sampler
        Controls the sampling of windows during validation (default:
        ``ValidationSplitSampler(min_future=prediction_length)``).
    """

    @validated()
    def __init__(
        self,
        freq: str,
        prediction_length: int,
        context_length: Optional[int] = None,
        input_size: int = 1,
        depth: int = 1,
        dim: int = 32,
        expansion_factor: int = 4,
        dropout: float = 0.1,
        batch_norm: bool = True,
        scaling: Optional[str] = "mean",
        num_feat_dynamic_real: int = 0,
        num_feat_static_cat: int = 0,
        num_feat_static_real: int = 0,
        cardinality: Optional[List[int]] = None,
        embedding_dimension: Optional[List[int]] = None,
        time_features: Optional[List[TimeFeature]] = None,
        lr: float = 1e-3,
        weight_decay: float = 1e-8,
        distr_output: DistributionOutput = StudentTOutput(),
        loss: DistributionLoss = NegativeLogLikelihood(),
        num_parallel_samples: int = 100,
        batch_size: int = 32,
        num_batches_per_epoch: int = 50,
        trainer_kwargs: Optional[Dict[str, Any]] = None,
        train_sampler: Optional[InstanceSampler] = None,
        validation_sampler: Optional[InstanceSampler] = None,
    ) -> None:
        # User-supplied trainer kwargs override the defaults key by key.
        default_trainer_kwargs = {
            "max_epochs": 100,
        }
        if trainer_kwargs is not None:
            default_trainer_kwargs.update(trainer_kwargs)
        super().__init__(trainer_kwargs=default_trainer_kwargs)

        self.scaling = scaling
        self.freq = freq
        self.input_size = input_size
        self.prediction_length = prediction_length
        self.context_length = context_length or 10 * prediction_length
        self.num_feat_dynamic_real = num_feat_dynamic_real
        self.num_feat_static_cat = num_feat_static_cat
        self.num_feat_static_real = num_feat_static_real
        # A single dummy category is used when there are no static cat
        # features (matching the dummy ``[0]`` set in the transformation).
        self.cardinality = (
            cardinality if cardinality and num_feat_static_cat > 0 else [1]
        )
        self.embedding_dimension = embedding_dimension
        self.time_features = (
            time_features
            if time_features is not None
            else time_features_from_frequency_str(self.freq)
        )

        # TODO find way to enforce same defaults to network and estimator
        # somehow
        self.depth = depth
        self.dim = dim
        self.expansion_factor = expansion_factor
        self.dropout = dropout
        self.batch_norm = batch_norm
        self.lr = lr
        self.weight_decay = weight_decay
        self.distr_output = distr_output
        self.num_parallel_samples = num_parallel_samples
        self.loss = loss
        self.batch_size = batch_size
        self.num_batches_per_epoch = num_batches_per_epoch

        self.train_sampler = train_sampler or ExpectedNumInstanceSampler(
            num_instances=1.0, min_future=prediction_length
        )
        self.validation_sampler = validation_sampler or ValidationSplitSampler(
            min_future=prediction_length
        )

    @classmethod
    def derive_auto_fields(cls, train_iter):
        """
        Compute constructor arguments that can be inferred from data.

        Returns a dict with ``num_feat_dynamic_real``,
        ``num_feat_static_cat`` and ``cardinality`` (the number of distinct
        values seen for each static categorical feature).
        """
        stats = calculate_dataset_statistics(train_iter)

        return {
            "num_feat_dynamic_real": stats.num_feat_dynamic_real,
            "num_feat_static_cat": len(stats.feat_static_cat),
            "cardinality": [len(cats) for cats in stats.feat_static_cat],
        }

    def create_transformation(self) -> Transformation:
        """
        Build the transformation applied to every dataset entry.

        Steps, in order: drop real-valued feature fields that are not
        configured, insert dummy static features when none are configured,
        add time features and a log-scaled age feature, stack them (plus any
        dynamic real features) into ``FieldName.FEAT_TIME``, and add the
        observed-values indicator for the target.
        """
        remove_field_names = []
        if self.num_feat_static_real == 0:
            remove_field_names.append(FieldName.FEAT_STATIC_REAL)
        if self.num_feat_dynamic_real == 0:
            remove_field_names.append(FieldName.FEAT_DYNAMIC_REAL)

        # Optional steps are chained in only when needed. ``Transformation +``
        # expects another Transformation; an empty-list placeholder (as in
        # ``x + []``) is not one and cannot be chained, which previously
        # broke the pipeline whenever static features were configured.
        transformation: Transformation = RemoveFields(
            field_names=remove_field_names
        )
        if not self.num_feat_static_cat > 0:
            transformation += SetField(
                output_field=FieldName.FEAT_STATIC_CAT, value=[0]
            )
        if not self.num_feat_static_real > 0:
            transformation += SetField(
                output_field=FieldName.FEAT_STATIC_REAL, value=[0.0]
            )

        return (
            transformation
            + AddTimeFeatures(
                start_field=FieldName.START,
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_TIME,
                time_features=self.time_features,
                pred_length=self.prediction_length,
            )
            + AddAgeFeature(
                target_field=FieldName.TARGET,
                output_field=FieldName.FEAT_AGE,
                pred_length=self.prediction_length,
                log_scale=True,
            )
            + VstackFeatures(
                output_field=FieldName.FEAT_TIME,
                input_fields=[FieldName.FEAT_TIME, FieldName.FEAT_AGE]
                + (
                    [FieldName.FEAT_DYNAMIC_REAL]
                    if self.num_feat_dynamic_real > 0
                    else []
                ),
            )
            + AddObservedValuesIndicator(
                target_field=FieldName.TARGET,
                output_field=FieldName.OBSERVED_VALUES,
            )
        )

    def create_lightning_module(self) -> pl.LightningModule:
        """Instantiate the ``TSMixerLightningModule`` wrapping the network."""
        # NOTE(review): ``cardinality``, ``embedding_dimension`` and
        # ``num_feat_static_real`` are stored on the estimator but not passed
        # in ``model_kwargs`` -- confirm whether the model needs them.
        return TSMixerLightningModule(
            loss=self.loss,
            lr=self.lr,
            weight_decay=self.weight_decay,
            model_kwargs={
                "input_size": self.input_size,
                "prediction_length": self.prediction_length,
                "context_length": self.context_length,
                "depth": self.depth,
                "dim": self.dim,
                "expansion_factor": self.expansion_factor,
                "dropout": self.dropout,
                "batch_norm": self.batch_norm,
                # Rows stacked into FEAT_TIME by ``create_transformation``:
                # one age feature, the dynamic real features, and one row
                # per time feature.
                "num_feat_dynamic_real": 1
                + self.num_feat_dynamic_real
                + len(self.time_features),
                "num_feat_static_cat": self.num_feat_static_cat,
                "scaling": self.scaling,
                "distr_output": self.distr_output,
                "num_parallel_samples": self.num_parallel_samples,
            },
        )

    def _create_instance_splitter(self, module: TSMixerLightningModule, mode: str):
        """
        Build the ``InstanceSplitter`` for ``mode``.

        ``mode`` is one of ``"training"``, ``"validation"`` or ``"test"``
        and selects the configured train/validation sampler or a
        ``TestSplitSampler``. ``module`` is unused.
        """
        assert mode in ["training", "validation", "test"]

        instance_sampler = {
            "training": self.train_sampler,
            "validation": self.validation_sampler,
            "test": TestSplitSampler(),
        }[mode]

        return InstanceSplitter(
            target_field=FieldName.TARGET,
            is_pad_field=FieldName.IS_PAD,
            start_field=FieldName.START,
            forecast_start_field=FieldName.FORECAST_START,
            instance_sampler=instance_sampler,
            past_length=self.context_length,
            future_length=self.prediction_length,
            time_series_fields=[FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES],
            # Presumably used to pad windows that start before the series
            # does; taken from the distribution's support so it is valid.
            dummy_value=self.distr_output.value_in_support,
        )

    def create_training_data_loader(
        self,
        data: Dataset,
        module: TSMixerLightningModule,
        shuffle_buffer_length: Optional[int] = None,
        **kwargs,
    ) -> Iterable:
        """
        Return an endless stream of training batches.

        The dataset is cycled, windows are sampled with the training
        sampler, and ``num_batches_per_epoch`` batches of
        ``TRAINING_INPUT_NAMES`` (as torch tensors) form one epoch.
        """
        data = Cyclic(data).stream()
        instances = self._create_instance_splitter(module, "training").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            shuffle_buffer_length=shuffle_buffer_length,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
            num_batches_per_epoch=self.num_batches_per_epoch,
        )

    def create_validation_data_loader(
        self,
        data: Dataset,
        module: TSMixerLightningModule,
        **kwargs,
    ) -> Iterable:
        """
        Return validation batches.

        Windows are sampled with the validation sampler, and the batches
        contain ``TRAINING_INPUT_NAMES`` as torch tensors.
        """
        instances = self._create_instance_splitter(module, "validation").apply(
            data, is_train=True
        )
        return as_stacked_batches(
            instances,
            batch_size=self.batch_size,
            field_names=TRAINING_INPUT_NAMES,
            output_type=torch.tensor,
        )

    def create_predictor(
        self,
        transformation: Transformation,
        module,
    ) -> PyTorchPredictor:
        """
        Wrap ``module`` into a ``PyTorchPredictor``.

        The predictor applies ``transformation`` followed by the test-mode
        instance splitter and runs on CUDA when it is available, otherwise
        on the CPU.
        """
        prediction_splitter = self._create_instance_splitter(module, "test")

        return PyTorchPredictor(
            input_transform=transformation + prediction_splitter,
            input_names=PREDICTION_INPUT_NAMES,
            prediction_net=module,
            batch_size=self.batch_size,
            prediction_length=self.prediction_length,
            device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )