69 changes: 67 additions & 2 deletions dicee/losses/custom_losses.py
@@ -4,7 +4,6 @@
import math
import numpy as np
import os
import numpy as np
from torch import tensor

class DefaultBCELoss(nn.Module):
@@ -374,4 +373,70 @@ def forward(self, inputs, targets, current_epoch):
loss_reg = self.get_reg(inputs, targets)
loss = loss_ce + self.alpha * loss_reg

return loss
return loss



class UNITEI(nn.Module):
"""
UNITE-I loss (pattern-aware & noise-resilient) for mini-batch training.
"""

def __init__(
self,
gamma: float = 5,
sigma: float = 1000,
lambda_q: float = 0.3,  # Regularization weight for the constraint penalty
clamp_scores: bool = True,
use_tanh_scale: bool = True,
):
super().__init__()
self.gamma = float(gamma)
self.sigma = float(sigma)
self.lambda_q = float(lambda_q)
self.clamp_scores = clamp_scores
self.use_tanh_scale = use_tanh_scale
self.eps = 1e-12

def _normalize_scores(self, pred: torch.Tensor) -> torch.Tensor:
# Keep scores in a range where gamma stays meaningful and gradients neither vanish nor explode.
x = pred
if self.use_tanh_scale:
# squashes to (-1,1), then scales to (-γ, γ)
x = torch.tanh(x) * self.gamma
if self.clamp_scores:
x = torch.clamp(x, min=-self.gamma, max=self.gamma)
return x

def forward(self, pred: torch.Tensor, target: torch.Tensor, current_epoch = None):

target = target.float()
pred = self._normalize_scores(pred)

tau_pos = F.relu(self.gamma - pred) # used only where target == 1
tau_neg = F.relu(pred - self.gamma) # used only where target == 0

loss_pos = self.sigma * (tau_pos ** 2)
loss_neg = self.sigma * (tau_neg ** 2)

# Constraint penalties with Q(x) = sigmoid(x)
# Pos: Q(gamma - f) >= Q(tau_pos) → violation = ReLU(Q(tau_pos) - Q(gamma - f)), where f = pred
# Neg: Q(f - gamma) >= Q(tau_neg) → violation = ReLU(Q(tau_neg) - Q(f - gamma))
Q_tau_pos = torch.sigmoid(tau_pos)
Q_tau_neg = torch.sigmoid(tau_neg)
Q_pos = torch.sigmoid(self.gamma - pred)
Q_neg = torch.sigmoid(pred - self.gamma)

pos_violation = F.relu(Q_tau_pos - Q_pos)
neg_violation = F.relu(Q_tau_neg - Q_neg)

# Mask by labels and combine
pos_mask = target
neg_mask = 1.0 - target

loss = (
pos_mask * (loss_pos + self.lambda_q * pos_violation)
+ neg_mask * (loss_neg + self.lambda_q * neg_violation)
)

return loss.mean()
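
A minimal usage sketch of the class above (shapes and values are illustrative; UNITEI is assumed importable from dicee.losses.custom_losses, the file modified here):

import torch
from dicee.losses.custom_losses import UNITEI

criterion = UNITEI(gamma=5, sigma=1000, lambda_q=0.3)
scores = torch.randn(4, 10, requires_grad=True)   # hypothetical raw scores, shape (batch, num_candidates)
labels = torch.randint(0, 2, (4, 10)).float()     # hypothetical binary targets
loss = criterion(scores, labels)                  # scalar mean loss; current_epoch is optional
loss.backward()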
46 changes: 39 additions & 7 deletions dicee/models/base_model.py
@@ -16,6 +16,7 @@
CombinedAdaptiveLSandAdaptiveLR,
AggregatedLSandLR,
ACLS,
UNITEI,
#GradientBasedLSLR,
#GradientBasedAdaptiveLSLR
)
@@ -42,12 +43,39 @@ def training_step(self, batch, batch_idx=None):
# Default
x_batch, y_batch = batch
yhat_batch = self.forward(x_batch)

if batch_idx == 0:
with torch.no_grad():
preds_sample = yhat_batch.view(-1)[:30]
mean_pred = preds_sample.mean().item()
min_pred = preds_sample.min().item()
max_pred = preds_sample.max().item()
self.log_dict({
"debug/pred_mean": mean_pred,
"debug/pred_min": min_pred,
"debug/pred_max": max_pred
}, prog_bar=True, on_step=True, logger=True)
Collaborator comment on lines +47 to +57:
This block should be inside the if len(batch)==2 branch. Also, since it duplicates the code at lines 65-74, you can define a class method for it.
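
One possible shape for that refactor, sketched from the duplicated block (method name and placement are illustrative, not part of this PR):

def _log_pred_stats(self, yhat_batch):
    with torch.no_grad():
        preds_sample = yhat_batch.view(-1)[:30]
        self.log_dict({
            "debug/pred_mean": preds_sample.mean().item(),
            "debug/pred_min": preds_sample.min().item(),
            "debug/pred_max": preds_sample.max().item(),
        }, prog_bar=True, on_step=True, logger=True)

# called from each branch of training_step:
# if batch_idx == 0:
#     self._log_pred_stats(yhat_batch)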



elif len(batch)==3:
# KvsSample or 1vsSample
x_batch, y_select, y_batch = batch
yhat_batch = self.forward((x_batch,y_select))
else:
raise RuntimeError("Invalid batch received.")
if batch_idx == 0:
with torch.no_grad():
preds_sample = yhat_batch.view(-1)[:30]
mean_pred = preds_sample.mean().item()
min_pred = preds_sample.min().item()
max_pred = preds_sample.max().item()
self.log_dict({
"debug/pred_mean": mean_pred,
"debug/pred_min": min_pred,
"debug/pred_max": max_pred
}, prog_bar=True, on_step=True, logger=True)

else:
raise RuntimeError("Invalid batch received.")


#total_norm = 0
#for param in self.parameters():
@@ -121,6 +149,7 @@ def configure_optimizers(self, parameters=None):
weight_decay=self.weight_decay)
elif self.optimizer_name == 'Adopt':
self.selected_optimizer = ADOPT(parameters, lr=self.learning_rate)

elif self.optimizer_name == 'AdamW':
self.selected_optimizer = torch.optim.AdamW(parameters, lr=self.learning_rate,
weight_decay=self.weight_decay)
@@ -137,7 +166,9 @@
weight_decay=self.weight_decay)
else:
raise KeyError(f"{self.optimizer_name} is not found!")

print(self.selected_optimizer)

return self.selected_optimizer


@@ -197,6 +228,9 @@ def __init__(self, args: dict):
self.loss = AggregatedLSandLR()
if self.args["loss_fn"] == "ACLS":
self.loss = ACLS()
if self.args["loss_fn"] == "UNITEI":
self.loss = UNITEI()

#if self.args["loss_fn"] == "GradientBasedLSLR":
# self.loss = GradientBasedLSLR()
#if self.args["loss_fn"] == "GradientBasedAdaptiveLSLR":
@@ -416,8 +450,7 @@ def get_triple_representation(self, idx_hrt):
# (1) Split input into indexes.
idx_head_entity, idx_relation, idx_tail_entity = idx_hrt[:, 0], idx_hrt[:, 1], idx_hrt[:, 2]
# (2) Retrieve embeddings & Apply Dropout & Normalization
head_ent_emb = self.normalize_head_entity_embeddings(
self.input_dp_ent_real(self.entity_embeddings(idx_head_entity)))
head_ent_emb = self.normalize_head_entity_embeddings(self.input_dp_ent_real(self.entity_embeddings(idx_head_entity)))
rel_ent_emb = self.normalize_relation_embeddings(self.input_dp_rel_real(self.relation_embeddings(idx_relation)))
tail_ent_emb = self.normalize_tail_entity_embeddings(self.entity_embeddings(idx_tail_entity))
return head_ent_emb, rel_ent_emb, tail_ent_emb
@@ -426,8 +459,7 @@ def get_head_relation_representation(self, indexed_triple):
# (1) Split input into indexes.
idx_head_entity, idx_relation = indexed_triple[:, 0], indexed_triple[:, 1]
# (2) Retrieve embeddings & Apply Dropout & Normalization
head_ent_emb = self.normalize_head_entity_embeddings(
self.input_dp_ent_real(self.entity_embeddings(idx_head_entity)))
head_ent_emb = self.normalize_head_entity_embeddings(self.input_dp_ent_real(self.entity_embeddings(idx_head_entity)))
rel_ent_emb = self.normalize_relation_embeddings(self.input_dp_rel_real(self.relation_embeddings(idx_relation)))
return head_ent_emb, rel_ent_emb

@@ -471,7 +503,7 @@ def get_bpe_head_and_relation_representation(self, x: torch.LongTensor) -> Tuple
# The sequence of sub-token embeddings representing a head entity should be unit-normalized,
# i.e. the Frobenius norm of each T-by-D slice must be 1.
# B, T, D
head_ent_emb = F.normalize(head_ent_emb, p=2, dim=(1, 2))
head_ent_emb = F.normalize(head_ent_emb, p=2, dim=(1, 2))  # L2-normalize over (T, D)
# B, T, D
rel_emb = F.normalize(rel_emb, p=2, dim=(1, 2))
return head_ent_emb, rel_emb
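
A quick sanity check of that invariant (illustrative only; assumes the installed torch version accepts a tuple for the dim argument of F.normalize, as the call above does):

import torch
import torch.nn.functional as F

x = torch.randn(2, 4, 8)            # B, T, D
x = F.normalize(x, p=2, dim=(1, 2))
print(x.flatten(1).norm(dim=1))     # ~tensor([1., 1.]): each B-slice has Frobenius norm 1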