diff --git a/pytorch/hyperbolic_parameter.py b/pytorch/hyperbolic_parameter.py index e19a2a9e..41224377 100644 --- a/pytorch/hyperbolic_parameter.py +++ b/pytorch/hyperbolic_parameter.py @@ -122,7 +122,7 @@ def exp(self, lr): cosh = torch.cosh(n) cosh[mask] = 1.0 sinh = torch.sinh(n) - sinh[mask] = 0.0 + sinh[mask] = 1.0 n[mask] = 1.0 e = cosh*x + sinh/n*v # assert torch.all(-HyperboloidParameter.dot_h(e,e) >= 0), torch.min(-HyperboloidParameter.dot_h(e,e))