Skip to content

Commit

Permalink
Update nonlinear.py
Browse files Browse the repository at this point in the history
  • Loading branch information
xunzheng authored Dec 17, 2020
1 parent 7786138 commit ba61337
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions notears/nonlinear.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from notears.locally_connected import LocallyConnected
from notears.lbfgsb_scipy import LBFGSBScipy
from notears.trace_expm import trace_expm
import torch
import torch.nn as nn
import numpy as np
Expand Down Expand Up @@ -52,10 +53,11 @@ def h_func(self):
fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight # [j * m1, i]
fc1_weight = fc1_weight.view(d, -1, d) # [j, m1, i]
A = torch.sum(fc1_weight * fc1_weight, dim=1).t() # [i, j]
# h = trace_expm(A) - d # (Zheng et al. 2018)
M = torch.eye(d) + A / d # (Yu et al. 2019)
E = torch.matrix_power(M, d - 1)
h = (E.t() * M).sum() - d
h = trace_expm(A) - d # (Zheng et al. 2018)
# A different formulation, slightly faster at the cost of numerical stability
# M = torch.eye(d) + A / d # (Yu et al. 2019)
# E = torch.matrix_power(M, d - 1)
# h = (E.t() * M).sum() - d
return h

def l2_reg(self):
Expand Down Expand Up @@ -130,10 +132,11 @@ def h_func(self):
fc1_weight = self.fc1_pos.weight - self.fc1_neg.weight # [j, ik]
fc1_weight = fc1_weight.view(self.d, self.d, self.k) # [j, i, k]
A = torch.sum(fc1_weight * fc1_weight, dim=2).t() # [i, j]
# h = trace_expm(A) - d # (Zheng et al. 2018)
M = torch.eye(self.d) + A / self.d # (Yu et al. 2019)
E = torch.matrix_power(M, self.d - 1)
h = (E.t() * M).sum() - self.d
h = trace_expm(A) - d # (Zheng et al. 2018)
# A different formulation, slightly faster at the cost of numerical stability
# M = torch.eye(self.d) + A / self.d # (Yu et al. 2019)
# E = torch.matrix_power(M, self.d - 1)
# h = (E.t() * M).sum() - self.d
return h

def l2_reg(self):
Expand Down

0 comments on commit ba61337

Please sign in to comment.