-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathpatchtst.yaml
48 lines (48 loc) · 1.08 KB
/
patchtst.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
# Lightning CLI config; written against lightning==2.3.0.dev0.
---
# Global RNG seed handed to Lightning's seed_everything.
seed_everything: 1
# Arguments for the Lightning Trainer. Every key below must be nested under
# `trainer:`; at column 0 they become unrelated top-level keys.
trainer:
  accelerator: gpu
  devices: 1
  strategy: auto
  # NOTE(review): matches model.lr_scheduler_config.init_args.epochs (50);
  # presumably the two must stay in sync for OneCycleLR.
  max_epochs: 50
  use_distributed_sampler: false
  # Caps each training epoch at 100 batches. NOTE(review): matches
  # model.lr_scheduler_config.init_args.steps_per_epoch (100); keep in sync.
  limit_train_batches: 100
  log_every_n_steps: 1
  default_root_dir: ./results
  accumulate_grad_batches: 1
# Arguments for the model (LightningModule). The nesting matters: flattened,
# the repeated `init_args` / `class_name` keys collide at the root.
model:
  # Point-forecasting PatchTST backbone, built from class_path + init_args.
  forecaster:
    class_path: probts.model.forecaster.point_forecaster.PatchTST
    init_args:
      stride: 3
      patch_len: 6
      dropout: 0.1
      f_hidden_size: 32
      n_layers: 3
      n_heads: 8
      fc_dropout: 0.2
      head_dropout: 0
      individual: false
  optimizer_config:
    class_name: torch.optim.Adam
    init_args:
      weight_decay: 0
  # OneCycleLR plans its schedule over epochs * steps_per_epoch steps.
  # NOTE(review): epochs / steps_per_epoch mirror trainer.max_epochs (50) and
  # trainer.limit_train_batches (100); update them together.
  lr_scheduler_config:
    class_name: torch.optim.lr_scheduler.OneCycleLR
    init_args:
      max_lr: 0.0001  # same value as learning_rate below; keep in sync
      steps_per_epoch: 100
      pct_start: 0.3
      epochs: 50
  learning_rate: 0.0001
  quantiles_num: 20
# Arguments for the data module. `data_manager` is built from
# class_path + init_args; the loader settings sit beside it under `data:`.
data:
  data_manager:
    class_path: probts.data.data_manager.DataManager
    init_args:
      dataset: exchange_rate_nips
      split_val: true
      scaler: standard  # one of: identity, standard, temporal
  batch_size: 64
  test_batch_size: 64
  num_workers: 8