From c397c6ce8df6bbb26ead5aba18e908d75ce72745 Mon Sep 17 00:00:00 2001 From: minhuanli Date: Thu, 30 May 2024 15:43:45 -0400 Subject: [PATCH] more careful with computing graph --- SFC_Torch/Fmodel.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/SFC_Torch/Fmodel.py b/SFC_Torch/Fmodel.py index 99e6d0f..67714c9 100644 --- a/SFC_Torch/Fmodel.py +++ b/SFC_Torch/Fmodel.py @@ -864,6 +864,8 @@ def _get_scales_lbfgs_LS( if initialize: self.init_scales(requires_grad=True) + else: + self.unfreeze_scales() def closure(): Fmodel = self.calc_ftotal(scale_mode=True) @@ -912,6 +914,8 @@ def _get_scales_lbfgs_r( if initialize: self.init_scales(requires_grad=True) + else: + self.unfreeze_scales() def closure(): Fmodel = self.calc_ftotal(scale_mode=True) @@ -960,6 +964,23 @@ def get_scales_lbfgs( ): self._get_scales_lbfgs_LS(ls_steps, ls_lr, verbose, initialize) self._get_scales_lbfgs_r(r_steps, r_lr, verbose, initialize=False) + self.freeze_scales() + + def freeze_scales(self): + """ + Do not require grad on scales + """ + self.kmasks = [kmask.requires_grad_(False) for kmask in self.kmasks] + self.kisos = [kiso.requires_grad_(False) for kiso in self.kisos] + self.uanisos = [uaniso.requires_grad_(False) for uaniso in self.uanisos] + + def unfreeze_scales(self): + """ + Do not require grad on scales + """ + self.kmasks = [kmask.requires_grad_(True) for kmask in self.kmasks] + self.kisos = [kiso.requires_grad_(True) for kiso in self.kisos] + self.uanisos = [uaniso.requires_grad_(True) for uaniso in self.uanisos] def get_scales_adam( self, @@ -1009,6 +1030,8 @@ def adam_stepopt(sub_boolean_mask): if initialize: self.init_scales(requires_grad=True) + else: + self.unfreeze_scales() for bin_i in range(self.n_bins): index_i = (self.bins == bin_i) & (~self.Outlier) @@ -1022,7 +1045,7 @@ def adam_stepopt(sub_boolean_mask): sub_ratio=sub_ratio, verbose=verbose, ) - + self.freeze_scales() Fmodel = self.calc_ftotal() self.r_work, self.r_free = self.get_rfactors(ftotal=Fmodel)