Skip to content

Commit

Permalink
grad norm
Browse files Browse the repository at this point in the history
  • Loading branch information
CW-Huang committed Sep 3, 2020
1 parent 9417b02 commit 002fe4c
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
1 change: 1 addition & 0 deletions hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
num_epochs=[10],
early_stop=[True],
ignore_w=[False],
grad_norm=['inf'],

w_transform=['Standardize'],
y_transform=['Normalize'],
Expand Down
3 changes: 3 additions & 0 deletions models/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(self, w, t, y, seed=1,
shuffle=True,
early_stop=True,
ignore_w=False,
grad_norm=float('inf'),
w_transform=PlaceHolderTransform,
t_transform=PlaceHolderTransform,
y_transform=PlaceHolderTransform,
Expand All @@ -98,6 +99,7 @@ def __init__(self, w, t, y, seed=1,
self.outcome_max = outcome_max
self.early_stop = early_stop
self.ignore_w = ignore_w
self.grad_norm = grad_norm
self.savepath = savepath

self.dim_w = self.w_transformed.shape[1]
Expand Down Expand Up @@ -172,6 +174,7 @@ def train(self, early_stop=None, print_=print):
loss, loss_t, loss_y = self._get_loss(w, t, y)
# TODO: learning rate can be separately adjusted by weighting the losses here
loss.backward()
torch.nn.utils.clip_grad_norm(chain(*[net.parameters() for net in self.networks]), self.grad_norm)
self.optim.step()

c += 1
Expand Down
2 changes: 2 additions & 0 deletions train_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def main(args):
seed=args.seed,
early_stop=args.early_stop,
ignore_w=args.ignore_w,
grad_norm=args.grad_norm,
w_transform=w_transform, y_transform=y_transform, # TODO set more args
savepath=os.path.join(args.saveroot, 'model.pt'))
# TODO GPU support
Expand Down Expand Up @@ -197,6 +198,7 @@ def main(args):
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--early_stop', type=eval, default=True, choices=[True, False])
parser.add_argument('--ignore_w', type=eval, default=False, choices=[True, False])
parser.add_argument('--grad_norm', type=float, default=float('inf'))

parser.add_argument('--w_transform', type=str, default='Standardize',
choices=preprocess.Preprocess.prep_names)
Expand Down

0 comments on commit 002fe4c

Please sign in to comment.