@@ -375,61 +375,68 @@ def forward(self, inputs, targets, current_epoch):
 
         return loss
 
+
+
 class UNITEI(nn.Module):
     """
-    Implementation of UNITE-I loss
+    UNITE-I loss (pattern-aware & noise-resilient) for mini-batch training.
     """
-    def __init__(self,
-                 gamma: float = 10,
-                 sigma: float = 1000.0,
-                 # penalty_weight: float = 1.0,
-                 # tau_init: float = 1e-3
-                 ):
-        super(UNITEI, self).__init__()
-
-        # self.num_triples = num_triples
-        self.gamma = gamma
-        self.sigma = sigma
-        # self.penalty_weight = penalty_weight
-        # self.tau = tau_init
-        # self.eps = 1e-12
-        self.relu = nn.ReLU()
-
-    def forward(self, pred, target, current_epoch=None):
-        """
-        This function calculates the UNITE loss.
-
-        Agrs:
-            pred: The output logits from the KGE model.
-            target: The target labesl
-
-        """
-        target = target.float()
-
-        #Separate positive and negative logits
-        l_pos = pred[target == 1.0]
-        l_neg = pred[target == 0.0]
-
-        loss_pos_total = torch.tensor(0.0, device=pred.device)
-        loss_neg_total = torch.tensor(0.0, device=pred.device)
-
-        if l_pos.numel() > 0:
-            tau_pos_star = self.relu(self.gamma - l_pos)
-            loss_pos = self.sigma * (tau_pos_star ** 2)
-            loss_pos_total = torch.sum(loss_pos)
-
-        if l_neg.numel() > 0:
-            tau_neg_star = self.relu(l_neg - self.gamma)
-            loss_neg = self.sigma * (tau_neg_star ** 2)
-            loss_neg_total = torch.sum(loss_neg)
-
-        total_loss = loss_pos_total + loss_neg_total
-
-        if pred.numel() == 0:
-            return torch.tensor(0.0, device=pred.device, requires_grad=True)
-
-        return total_loss / pred.numel()
-
-
-
 
+    def __init__(
+        self,
+        gamma: float = 5,
+        sigma: float = 1000,
+        lambda_q: float = 0.3,  # Regularization weight for the constraint penalty
+        clamp_scores: bool = True,
+        use_tanh_scale: bool = True,
+    ):
+        super().__init__()
+        self.gamma = float(gamma)
+        self.sigma = float(sigma)
+        self.lambda_q = float(lambda_q)
+        self.clamp_scores = clamp_scores
+        self.use_tanh_scale = use_tanh_scale
+        self.eps = 1e-12
+
+    def _normalize_scores(self, pred: torch.Tensor) -> torch.Tensor:
+        # Normalize scores so gamma stays meaningful and gradients neither vanish nor explode.
+        x = pred
+        if self.use_tanh_scale:
+            # squash to (-1, 1), then scale to (-γ, γ)
+            x = torch.tanh(x) * self.gamma
+        if self.clamp_scores:
+            x = torch.clamp(x, min=-self.gamma, max=self.gamma)
+        return x
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor, current_epoch=None):
+
+        target = target.float()
+        pred = self._normalize_scores(pred)
+
+        tau_pos = F.relu(self.gamma - pred)  # used only where target == 1
+        tau_neg = F.relu(pred - self.gamma)  # used only where target == 0
+
+        loss_pos = self.sigma * (tau_pos ** 2)
+        loss_neg = self.sigma * (tau_neg ** 2)
+
+        # Constraint penalties with Q(x) = sigmoid(x)
+        # Pos: Q(gamma - f) >= Q(tau) → violation = ReLU(Q(tau) - Q(gamma - f)), where f = pred
+        # Neg: Q(f - gamma) >= Q(tau) → violation = ReLU(Q(tau) - Q(f - gamma))
+        Q_tau_pos = torch.sigmoid(tau_pos)
+        Q_tau_neg = torch.sigmoid(tau_neg)
+        Q_pos = torch.sigmoid(self.gamma - pred)
+        Q_neg = torch.sigmoid(pred - self.gamma)
+
+        pos_violation = F.relu(Q_tau_pos - Q_pos)
+        neg_violation = F.relu(Q_tau_neg - Q_neg)
+
+        # Mask by labels and combine
+        pos_mask = target
+        neg_mask = 1.0 - target
+
+        loss = (
+            pos_mask * (loss_pos + self.lambda_q * pos_violation)
+            + neg_mask * (loss_neg + self.lambda_q * neg_violation)
+        )
+
+        return loss.mean()
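
A minimal usage sketch for the new loss. Assumptions not shown in this hunk: `UNITEI` is importable from this module (the `unite_loss` path below is hypothetical), `torch.nn.functional` is imported as `F` at the top of the file, and `pred`/`target` are flat tensors of raw KGE scores and 0/1 labels; shapes and values are illustrative only.

import torch
from unite_loss import UNITEI  # hypothetical import path; adjust to the repo layout

criterion = UNITEI(gamma=5, sigma=1000, lambda_q=0.3)

pred = torch.randn(8, requires_grad=True)                 # raw scores from a KGE model
target = torch.tensor([1., 0., 1., 0., 0., 1., 0., 0.])   # 1 = true triple, 0 = corrupted

loss = criterion(pred, target)   # scalar, averaged over all scores
loss.backward()                  # gradients flow back into pred
print(loss.item())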