diff --git a/head/metrics.py b/head/metrics.py index 466dc83..4ee714b 100644 --- a/head/metrics.py +++ b/head/metrics.py @@ -79,10 +79,7 @@ def __init__(self, in_features, out_features, s = 64.0, m = 0.50, easy_margin = def forward(self, embbedings, label): embbedings = l2_norm(embbedings, axis = 1) kernel_norm = l2_norm(self.kernel, axis = 0) - cos_theta = torch.mm(embbedings, kernel_norm) - cos_theta = cos_theta.clamp(-1, 1) # for numerical stability - with torch.no_grad(): - origin_cos = cos_theta.clone() + cos_theta = torch.mm(embbedings, kernel_norm).clamp_(-1, 1) # for numerical stability target_logit = cos_theta[torch.arange(0, embbedings.size(0)), label].view(-1, 1) sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2)) @@ -341,10 +338,7 @@ def __init__(self, in_features, out_features, m = 0.5, s = 64.): def forward(self, embbedings, label): embbedings = l2_norm(embbedings, axis = 1) kernel_norm = l2_norm(self.kernel, axis = 0) - cos_theta = torch.mm(embbedings, kernel_norm) - cos_theta = cos_theta.clamp(-1, 1) # for numerical stability - with torch.no_grad(): - origin_cos = cos_theta.clone() + cos_theta = torch.mm(embbedings, kernel_norm).clamp_(-1, 1) # for numerical stability target_logit = cos_theta[torch.arange(0, embbedings.size(0)), label].view(-1, 1) sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2)) @@ -584,4 +578,4 @@ def forward(self, input, label): output = torch.logsumexp(logit_n, dim=1) + torch.logsumexp(logit_p, dim=1) - return output \ No newline at end of file + return output