-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.yaml
86 lines (84 loc) · 2.69 KB
/
config.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Hydra defaults list: the config groups composed into this primary config.
defaults:
  # base choices; the interpolated entries below pick their option by these names
  - model: nrms
  - embedding_layer: word_embedding
  # hyperparameter groups keyed by the chosen model / embedding layer
  - hparams_model: ${model}
  - hparams_embedding: ${embedding_layer}
  # groups whose option is named after the chosen embedding layer
  - optim: ${embedding_layer}
  - data_path: ${embedding_layer}
  - dataset: ${embedding_layer}
  - deterministic: disabled
  # `optional`: composition does not fail if this option file is missing
  - optional debug_run: disabled
  # keys defined in this file are merged last, after all groups above
  - _self_
# option chosen for each config group, taken from Hydra's runtime choices;
# read by hparams.model and hparams.embedding_layer
default_choices: ${hydra:runtime.choices}
# directory segment under dataset/mind/ used by every data_path entry
dataset_type: small
# worker count referenced by both the train and valid data loaders
num_workers: 4
# Locations of the processed dataset files. Every path is rooted at the
# WORKDIR environment variable (no default is given to oc.env, so WORKDIR
# must be set) and selected by dataset_type.
data_path:
  train:
    samples: ${oc.env:WORKDIR}/dataset/mind/${dataset_type}/processed/train/behaviors.pkl
    articles: ${oc.env:WORKDIR}/dataset/mind/${dataset_type}/processed/train/articles.pkl
  valid:
    samples: ${oc.env:WORKDIR}/dataset/mind/${dataset_type}/processed/valid/behaviors.pkl
    articles: ${oc.env:WORKDIR}/dataset/mind/${dataset_type}/processed/valid/articles.pkl
  # hparams reads these as ${data_path.category} / ${data_path.subcategory},
  # so they sit directly under data_path rather than under train/valid
  category: ${oc.env:WORKDIR}/dataset/mind/${dataset_type}/processed/categories.txt
  subcategory: ${oc.env:WORKDIR}/dataset/mind/${dataset_type}/processed/subcategories.txt
# Hyperparameters defined by this file. Other parts of the config also read
# hparams.batch_size and hparams.precision, which are not defined here —
# presumably supplied by the hparams_* groups in the defaults list; verify.
hparams:
  # value of ${default_choices} becomes invalid under the multirun setting
  # (hydra v1.1): https://github.com/facebookresearch/hydra/issues/1882
  model: ${default_choices.model}
  embedding_layer: ${default_choices.embedding_layer}
  article_attributes:
    - title
    - body
    - category
    - subcategory
  embedding_dim: 400
  attn_hidden_dim: 200
  category_hidden_dim: 100
  p_dropout: 0.2
  # `line_count` is not an OmegaConf built-in resolver; presumably registered
  # by the project before the config is resolved — verify
  n_categories: ${line_count:${data_path.category}}
  n_subcategories: ${line_count:${data_path.subcategory}}
  data_shuffle_seed: 0
  train_seed: 0
  n_negatives: 4
  n_epochs: 3
  max_history_length: 50
  # copies the top-level optim node (populated by the optim config group)
  optim: ${optim}
# DataLoader specs for the train and valid splits; each `_target_` names the
# class to construct (presumably via hydra.utils.instantiate — verify).
# hparams.batch_size is not defined in this file; presumably it comes from one
# of the hparams_* config groups — verify.
data_loader:
  train:
    _target_: torch.utils.data.DataLoader
    batch_size: ${hparams.batch_size.train}
    shuffle: true
    num_workers: ${num_workers}
    collate_fn:
      _target_: mind_recommenders_pytorch.train.data.input_collator.InputCollator
  valid:
    _target_: torch.utils.data.DataLoader
    batch_size: ${hparams.batch_size.valid}
    shuffle: false
    num_workers: ${num_workers}
    collate_fn:
      _target_: mind_recommenders_pytorch.train.data.input_collator.InputCollator
# Class path of the project's PlModel; no constructor arguments are set here.
pl_model:
  _target_: mind_recommenders_pytorch.train.pl_model.PlModel
# pytorch_lightning.Trainer spec.
# NOTE(review): `callbacks` and `logger` are placed under `trainer` because
# both are Trainer constructor argument names — confirm against the code that
# consumes this config.
trainer:
  _target_: pytorch_lightning.Trainer
  default_root_dir: './'
  max_epochs: ${hparams.n_epochs}
  gpus:
    - 0
  detect_anomaly: false
  # hparams.precision is not defined in this file; presumably supplied by one
  # of the hparams_* config groups — verify
  precision: ${hparams.precision}
  callbacks:
    # checkpoints are ranked by val_auc and the value is embedded in the name
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      monitor: val_auc
      filename: '{epoch:04d}-{step:04d}-{val_auc:.3f}'
      save_top_k: 10
  logger:
    _target_: pytorch_lightning.loggers.TensorBoardLogger
    # same directory as the trainer's default_root_dir
    save_dir: ${trainer.default_root_dir}
    name: train_logs
# Hydra output directories: single runs and each multirun job both land in
# ./logs/<override_dirname>, where override_dirname comes from the job's overrides.
hydra:
  run:
    dir: ./logs/${hydra.job.override_dirname}
  sweep:
    dir: ./logs/
    subdir: ${hydra.job.override_dirname}