Skip to content

Commit

Permalink
more careful with computing graph
Browse files Browse the repository at this point in the history
  • Loading branch information
minhuanli committed May 30, 2024
1 parent 1766872 commit c397c6c
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion SFC_Torch/Fmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,8 @@ def _get_scales_lbfgs_LS(

if initialize:
self.init_scales(requires_grad=True)
else:
self.unfreeze_scales()

def closure():
Fmodel = self.calc_ftotal(scale_mode=True)
Expand Down Expand Up @@ -912,6 +914,8 @@ def _get_scales_lbfgs_r(

if initialize:
self.init_scales(requires_grad=True)
else:
self.unfreeze_scales()

def closure():
Fmodel = self.calc_ftotal(scale_mode=True)
Expand Down Expand Up @@ -960,6 +964,23 @@ def get_scales_lbfgs(
):
self._get_scales_lbfgs_LS(ls_steps, ls_lr, verbose, initialize)
self._get_scales_lbfgs_r(r_steps, r_lr, verbose, initialize=False)
self.freeze_scales()

def freeze_scales(self):
"""
Do not require grad on scales
"""
self.kmasks = [kmask.requires_grad_(False) for kmask in self.kmasks]
self.kisos = [kiso.requires_grad_(False) for kiso in self.kisos]
self.uanisos = [uaniso.requires_grad_(False) for uaniso in self.uanisos]

def unfreeze_scales(self):
"""
Do not require grad on scales
"""
self.kmasks = [kmask.requires_grad_(True) for kmask in self.kmasks]
self.kisos = [kiso.requires_grad_(True) for kiso in self.kisos]
self.uanisos = [uaniso.requires_grad_(True) for uaniso in self.uanisos]

def get_scales_adam(
self,
Expand Down Expand Up @@ -1009,6 +1030,8 @@ def adam_stepopt(sub_boolean_mask):

if initialize:
self.init_scales(requires_grad=True)
else:
self.unfreeze_scales()

for bin_i in range(self.n_bins):
index_i = (self.bins == bin_i) & (~self.Outlier)
Expand All @@ -1022,7 +1045,7 @@ def adam_stepopt(sub_boolean_mask):
sub_ratio=sub_ratio,
verbose=verbose,
)

self.freeze_scales()
Fmodel = self.calc_ftotal()
self.r_work, self.r_free = self.get_rfactors(ftotal=Fmodel)

Expand Down

0 comments on commit c397c6c

Please sign in to comment.