Skip to content

Commit c1eed9d

Browse files
committed
Try unscale
1 parent dd69813 commit c1eed9d

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

kronfluence/module/tracker/base.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,18 @@ def _preprocess_gradient(self, output_gradient: torch.Tensor, target_dtype: torc
6060
torch.Tensor:
6161
The preprocessed gradient.
6262
"""
63-
original_dtype = output_gradient.dtype
63+
# original_dtype = output_gradient.dtype
64+
# output_gradient = output_gradient.to(dtype=target_dtype)
65+
# if self.module.gradient_scale != 1.0:
66+
# if original_dtype != target_dtype:
67+
# output_gradient.mul_(self.module.gradient_scale)
68+
# else:
69+
# output_gradient = output_gradient * self.module.gradient_scale
70+
# return output_gradient
6471
output_gradient = output_gradient.to(dtype=target_dtype)
6572
if self.module.gradient_scale != 1.0:
66-
if original_dtype != target_dtype:
67-
output_gradient.mul_(self.module.gradient_scale)
68-
else:
69-
output_gradient = output_gradient * self.module.gradient_scale
73+
output_gradient = output_gradient * self.module.gradient_scale
74+
output_gradient = output_gradient.to(dtype=target_dtype)
7075
return output_gradient
7176

7277
def register_hooks(self) -> None:

0 commit comments

Comments
 (0)