Skip to content

Commit

Permalink
update autoregressive model (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
NoemiAM authored Mar 30, 2024
1 parent 393ed6c commit 24b94bb
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion swyft/lightning/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def __init__(
dropout=0.1,
num_blocks=2,
hidden_features=64,
min_l2 = None
):
super().__init__()
self.cl1 = swyft.LogRatioEstimator_1dim(
Expand All @@ -372,6 +373,7 @@ def __init__(
Lmax=0,
)
self.num_params = num_params
self.min_l2 = min_l2

def forward(self, xA, zA, zB):
xA, zB = swyft.equalize_tensors(xA, zB)
Expand All @@ -397,7 +399,8 @@ def forward(self, xA, zA, zB):

l1 = logratios1.logratios.sum(-1)
l2 = logratios2.logratios.sum(-1)
l2 = torch.where(l2 > 0, l2, 0)
if self.min_l2 is not None:
l2 = torch.where(l2 > self.min_l2, l2, self.min_l2)
l = (l1 - l2).detach().unsqueeze(-1)

logratios_tot = swyft.LogRatioSamples(l, logratios1.params, logratios1.parnames)
Expand Down

0 comments on commit 24b94bb

Please sign in to comment.