-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
36 lines (26 loc) · 1.06 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from typing import Union
from pydantic import BaseModel, Field, model_validator
class ModelConfig(BaseModel):
    """Model architecture hyperparameters, validated by pydantic.

    Every field has a default, so ``ModelConfig()`` yields the baseline
    configuration. Size/count fields must be strictly positive and
    ``dropout`` must lie in ``[0, 1]``; out-of-range values raise
    ``pydantic.ValidationError`` when the config is constructed, instead of
    failing later inside model code.
    """

    embed_dims: int = Field(default=128, gt=0)
    hidden_dims: int = Field(default=256, gt=0)
    num_heads: int = Field(default=2, gt=0)
    # NOTE(review): TrainingConfig also defines max_sequence_length with the
    # same default; presumably the two must agree — confirm with callers.
    max_sequence_length: int = Field(default=100, gt=0)
    num_layers: int = Field(default=2, gt=0)
    # Assumed to be a dropout probability, hence bounded to [0, 1].
    dropout: float = Field(default=0.1, ge=0.0, le=1.0)
    num_components: int = Field(default=1, gt=0)
    # Left unconstrained: the valid range depends on how the model uses it.
    elu_alpha: float = Field(default=3.0)
    sigma_min: float = Field(default=0.008)  # 0.004 == 1 pixel
class TrainingConfig(BaseModel):
    """Training-run settings, validated by pydantic.

    Every field has a default, so ``TrainingConfig()`` yields the baseline
    run. Epoch count, batch size, sequence length, logging frequency and the
    learning rate must be strictly positive; out-of-range values raise
    ``pydantic.ValidationError`` when the config is constructed.
    """

    # data
    num_epochs: int = Field(default=100, gt=0)
    batch_size: int = Field(default=128, gt=0)
    # NOTE(review): ModelConfig also defines max_sequence_length with the
    # same default; presumably the two must agree — confirm with callers.
    max_sequence_length: int = Field(default=100, gt=0)
    # Left unconstrained: valid range not evident from this file — TODO confirm.
    data_sparsity: int = Field(default=1)
    aug_scale_factor: float = Field(default=0.05)
    # optimizer
    learning_rate: float = Field(default=3e-04, gt=0.0)
    # None is accepted; presumably it disables gradient clipping — confirm.
    gradient_clip: Union[float, None] = Field(default=1.0)
    # logging
    wandb_mode: str = Field(default="online")
    log_frequency: int = Field(default=100, gt=0)
    save_parent_dir: Union[str, None] = Field(default="checkpoints")