
Commit 629ba61

Added only custom_losses.py and base_model.py
1 parent b2896a8

2 files changed: +92 -58 lines changed

dicee/losses/custom_losses.py
Lines changed: 61 additions & 54 deletions

@@ -375,61 +375,68 @@ def forward(self, inputs, targets, current_epoch):
 
         return loss
 
+
+
 class UNITEI(nn.Module):
     """
-    Implementation of UNITE-I loss
+    UNITE-I loss (pattern-aware & noise-resilient) for mini-batch training.
     """
-    def __init__(self,
-                 gamma: float = 10,
-                 sigma: float = 1000.0,
-                 # penalty_weight: float = 1.0,
-                 # tau_init: float = 1e-3
-                 ):
-        super(UNITEI, self).__init__()
-
-        # self.num_triples = num_triples
-        self.gamma = gamma
-        self.sigma = sigma
-        # self.penalty_weight = penalty_weight
-        # self.tau = tau_init
-        # self.eps = 1e-12
-        self.relu = nn.ReLU()
-
-    def forward(self, pred, target, current_epoch = None):
-        """
-        This function calculates the UNITE loss.
-
-        Agrs:
-            pred: The output logits from the KGE model.
-            target: The target labesl
-
-        """
-        target = target.float()
-
-        #Separate positive and negative logits
-        l_pos = pred[target == 1.0]
-        l_neg = pred[target == 0.0]
-
-        loss_pos_total = torch.tensor(0.0, device=pred.device)
-        loss_neg_total = torch.tensor(0.0, device=pred.device)
-
-        if l_pos.numel() > 0:
-            tau_pos_star = self.relu(self.gamma - l_pos)
-            loss_pos = self.sigma * (tau_pos_star ** 2)
-            loss_pos_total = torch.sum(loss_pos)
-
-        if l_neg.numel() > 0:
-            tau_neg_star = self.relu(l_neg - self.gamma)
-            loss_neg = self.sigma * (tau_neg_star ** 2)
-            loss_neg_total = torch.sum(loss_neg)
-
-        total_loss = loss_pos_total + loss_neg_total
-
-        if pred.numel() == 0:
-            return torch.tensor(0.0, device=pred.device, requires_grad=True)
-
-        return total_loss / pred.numel()
-
-
-
 
+    def __init__(
+        self,
+        gamma: float = 5,
+        sigma: float = 1000,
+        lambda_q: float = 0.3,  # Regularization weight
+        clamp_scores: bool = True,
+        use_tanh_scale: bool = True,
+    ):
+        super().__init__()
+        self.gamma = float(gamma)
+        self.sigma = float(sigma)
+        self.lambda_q = float(lambda_q)
+        self.clamp_scores = clamp_scores
+        self.use_tanh_scale = use_tanh_scale
+        self.eps = 1e-12
+
+    def _normalize_scores(self, pred: torch.Tensor) -> torch.Tensor:
+        # Keep gamma meaningful so gradients neither die nor explode.
+        x = pred
+        if self.use_tanh_scale:
+            # squashes to (-1, 1), then scales to (-gamma, gamma)
+            x = torch.tanh(x) * self.gamma
+        if self.clamp_scores:
+            x = torch.clamp(x, min=-self.gamma, max=self.gamma)
+        return x
+
+    def forward(self, pred: torch.Tensor, target: torch.Tensor, current_epoch=None):
+
+        target = target.float()
+        pred = self._normalize_scores(pred)
+
+        tau_pos = F.relu(self.gamma - pred)  # used only where target == 1
+        tau_neg = F.relu(pred - self.gamma)  # used only where target == 0
+
+        loss_pos = self.sigma * (tau_pos ** 2)
+        loss_neg = self.sigma * (tau_neg ** 2)
+
+        # Constraint penalties with Q(x) = sigmoid(x)
+        # Pos: Q(gamma - f) >= Q(tau) -> violation = ReLU(Q(tau) - Q(gamma - f)), where f = pred
+        # Neg: Q(f - gamma) >= Q(tau) -> violation = ReLU(Q(tau) - Q(f - gamma))
+        Q_tau_pos = torch.sigmoid(tau_pos)
+        Q_tau_neg = torch.sigmoid(tau_neg)
+        Q_pos = torch.sigmoid(self.gamma - pred)
+        Q_neg = torch.sigmoid(pred - self.gamma)
+
+        pos_violation = F.relu(Q_tau_pos - Q_pos)
+        neg_violation = F.relu(Q_tau_neg - Q_neg)
+
+        # Mask by labels and combine
+        pos_mask = target
+        neg_mask = 1.0 - target
+
+        loss = (
+            pos_mask * (loss_pos + self.lambda_q * pos_violation)
+            + neg_mask * (loss_neg + self.lambda_q * neg_violation)
+        )
+
+        return loss.mean()
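For reference, the objective implemented by the new UNITEI.forward, written out. This is a transcription of the code above, not a formula quoted from a paper: with f_i the raw score after _normalize_scores, label y_i in {0, 1}, and Q(x) = sigmoid(x),

\[
\tau_i^{+} = \max(0,\, \gamma - f_i), \qquad \tau_i^{-} = \max(0,\, f_i - \gamma), \qquad Q(x) = \frac{1}{1 + e^{-x}}
\]
\[
\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \Big[ y_i \big( \sigma (\tau_i^{+})^{2} + \lambda_Q \max\big(0,\, Q(\tau_i^{+}) - Q(\gamma - f_i)\big) \big) + (1 - y_i) \big( \sigma (\tau_i^{-})^{2} + \lambda_Q \max\big(0,\, Q(\tau_i^{-}) - Q(f_i - \gamma)\big) \big) \Big]
\]

where N = pred.numel() and the averaging comes from the final loss.mean().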

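A minimal smoke test for the new class. It assumes a dicee checkout on the import path (the module path matches the file added in this commit) and that custom_losses.py imports torch and torch.nn.functional as F at module level, which the use of F.relu above implies; shapes and values below are illustrative only:

    import torch
    from dicee.losses.custom_losses import UNITEI  # module path as added in this commit

    loss_fn = UNITEI()                           # defaults: gamma=5, sigma=1000, lambda_q=0.3
    pred = torch.randn(16, requires_grad=True)   # raw KGE scores (logits)
    target = (torch.rand(16) > 0.5).float()      # 1.0 = positive triple, 0.0 = negative
    loss = loss_fn(pred, target)
    loss.backward()                              # gradients flow through the tanh scaling
    print(f"loss={loss.item():.4f}  grad_norm={pred.grad.norm().item():.4f}")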
dicee/models/base_model.py
Lines changed: 31 additions & 4 deletions

@@ -43,12 +43,39 @@ def training_step(self, batch, batch_idx=None):
             # Default
             x_batch, y_batch = batch
             yhat_batch = self.forward(x_batch)
+
+            if batch_idx == 0:
+                with torch.no_grad():
+                    preds_sample = yhat_batch.view(-1)[:30]
+                    mean_pred = preds_sample.mean().item()
+                    min_pred = preds_sample.min().item()
+                    max_pred = preds_sample.max().item()
+                    self.log_dict({
+                        "debug/pred_mean": mean_pred,
+                        "debug/pred_min": min_pred,
+                        "debug/pred_max": max_pred
+                    }, prog_bar=True, on_step=True, logger=True)
+
+
         elif len(batch)==3:
             # KvsSample or 1vsSample
             x_batch, y_select, y_batch = batch
             yhat_batch = self.forward((x_batch,y_select))
-        else:
-            raise RuntimeError("Invalid batch received.")
+            if batch_idx == 0:
+                with torch.no_grad():
+                    preds_sample = yhat_batch.view(-1)[:30]
+                    mean_pred = preds_sample.mean().item()
+                    min_pred = preds_sample.min().item()
+                    max_pred = preds_sample.max().item()
+                    self.log_dict({
+                        "debug/pred_mean": mean_pred,
+                        "debug/pred_min": min_pred,
+                        "debug/pred_max": max_pred
+                    }, prog_bar=True, on_step=True, logger=True)
+
+        else:
+            raise RuntimeError("Invalid batch received.")
+
 
         #total_norm = 0
         #for param in self.parameters():
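The debug-logging block added in this hunk appears verbatim in both the 2-tuple and the 3-tuple branch. A small helper method would keep the two copies in sync; a sketch under the assumption of a standard LightningModule, where the name _log_pred_stats is mine and not part of the commit:

    import torch

    def _log_pred_stats(self, yhat_batch: torch.Tensor) -> None:
        # Sketch of a method on the model class: summarize the first 30
        # predictions of the batch and push them to the progress bar and logger.
        with torch.no_grad():
            preds_sample = yhat_batch.view(-1)[:30]
            self.log_dict(
                {
                    "debug/pred_mean": preds_sample.mean().item(),
                    "debug/pred_min": preds_sample.min().item(),
                    "debug/pred_max": preds_sample.max().item(),
                },
                prog_bar=True, on_step=True, logger=True,
            )

Each branch would then reduce to a single line: if batch_idx == 0: self._log_pred_stats(yhat_batch).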
@@ -202,8 +229,8 @@ def __init__(self, args: dict):
         if self.args["loss_fn"] == "ACLS":
             self.loss = ACLS()
         if self.args["loss_fn"] == "UNITEI":
-            self.loss = UNITEI(gamma = self.args.get("unite_gamma", 5.0),
-                               sigma = self.args.get("unite_sigma", 1000.0))
+            self.loss = UNITEI()
+
         #if self.args["loss_fn"] == "GradientBasedLSLR":
         #    self.loss = GradientBasedLSLR()
         #if self.args["loss_fn"] == "GradientBasedAdaptiveLSLR":
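Note that this hunk drops the args-driven construction: the unite_gamma and unite_sigma config keys no longer reach the loss, and the new lambda_q knob is not wired up at all. If configurability is still wanted, something like the following would restore it; the unite_lambda_q key is hypothetical, while the other two keys existed before this commit:

    if self.args["loss_fn"] == "UNITEI":
        self.loss = UNITEI(
            gamma=self.args.get("unite_gamma", 5.0),
            sigma=self.args.get("unite_sigma", 1000.0),
            lambda_q=self.args.get("unite_lambda_q", 0.3),  # hypothetical config key
        )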

0 commit comments