diff --git a/src/zeroband/optimizers/__init__.py b/src/zeroband/optimizers/__init__.py index ec3419a1..ccd50bca 100644 --- a/src/zeroband/optimizers/__init__.py +++ b/src/zeroband/optimizers/__init__.py @@ -43,9 +43,9 @@ def get_optimizer(params: list[torch.nn.Parameter], config: OptimizersConfig) -> lr=config.lr, betas=(config.betas1, config.betas2), epsilon=1e-12, - weight_decay=1e-05, - max_preconditioner_dim=8192, - precondition_frequency=100, + weight_decay=config.weight_decay, + max_preconditioner_dim=config.max_preconditioner_dim, + precondition_frequency=config.precondition_frequency, use_decoupled_weight_decay=True, # This can also be set to `QREigenvalueCorrectionConfig` which is less expensive # and might therefore allow for a smaller `precondition_frequency`.