Skip to content

Commit 3eed766

Browse files
committed
fix bug of deleting multipliers, now only do for erm
1 parent 61d7d98 commit 3eed766

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

domainlab/algos/trainers/a_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def cal_reg_loss(self, tensor_x, tensor_y, tensor_d, others=None):
319319
list_reg_loss_trainer_tensor
320320
list_mu = list_mu_model + list_mu_trainer
321321
# ERM return a tensor of all zeros, delete here
322-
if len(list_mu) > 1:
322+
if len(list_mu) > 1 and "ModelERM" == type(self.get_model()).__name__:
323323
list_boolean_zero = [torch.all(torch.eq(list_loss_tensor[i], 0)).item()
324324
for i in range(len(list_mu))]
325325
list_loss_tensor = [list_loss_tensor[i] for (i, flag) in

0 commit comments

Comments
 (0)