Skip to content

Commit

Permalink
Blacking code (#145)
Browse files Browse the repository at this point in the history
  • Loading branch information
cweniger authored Sep 15, 2023
1 parent ba1b474 commit 393ed6c
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 30 deletions.
20 changes: 12 additions & 8 deletions swyft/lightning/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ class SwyftDataModule(pl.LightningDataModule):
def __init__(
self,
data,
#lengths: Union[Sequence[int], None] = None,
#fractions: Union[Sequence[float], None] = None,
# lengths: Union[Sequence[int], None] = None,
# fractions: Union[Sequence[float], None] = None,
val_fraction: float = 0.2,
batch_size: int = 32,
num_workers: int = 0,
Expand All @@ -54,13 +54,17 @@ def __init__(
self.data = data
# TODO: Clean up codes
lengths = None
fractions = [1-val_fraction, val_fraction]
fractions = [1 - val_fraction, val_fraction]
if lengths is not None and fractions is None:
assert len(lengths) == 2, "SwyftDataModule only provides training and validation data."
assert (
len(lengths) == 2
), "SwyftDataModule only provides training and validation data."
lengths = [lengths[0], lenghts[1], 0]
self.lengths = lengths
elif lengths is None and fractions is not None:
assert len(fractions) == 2, "SwyftDataModule only provides training and validation data."
assert (
len(fractions) == 2
), "SwyftDataModule only provides training and validation data."
fractions = [fractions[0], fractions[1], 0]
self.lengths = self._get_lengths(fractions, len(data))
else:
Expand Down Expand Up @@ -122,6 +126,8 @@ def val_dataloader(self):

def test_dataloader(self):
return


# dataloader = torch.utils.data.DataLoader(
# self.dataset_test,
# batch_size=self.batch_size,
Expand Down Expand Up @@ -186,9 +192,7 @@ def reset_length(self, N, clubber=False):
for k in self.data.keys():
shape = self.data[k].shape
self.data[k].resize(N, *shape[1:])
self.root["meta/sim_status"].resize(
N,
)
self.root["meta/sim_status"].resize(N,)

def init(self, N, chunk_size, shapes=None, dtypes=None):
if len(self) > 0:
Expand Down
18 changes: 9 additions & 9 deletions swyft/lightning/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,24 +319,24 @@ def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:

# log r(x, z) = log p(x, z)/p(x)/p(z), with covariance given by [[x_var, xz_cov], [xz_cov, z_var]]
x, z = swyft.equalize_tensors(x, z)
xb = (x - self.x_mean) / self.x_var**0.5
zb = (z - self.z_mean) / self.z_var**0.5
rho = self.xz_cov / self.x_var**0.5 / self.z_var**0.5
xb = (x - self.x_mean) / self.x_var ** 0.5
zb = (z - self.z_mean) / self.z_var ** 0.5
rho = self.xz_cov / self.x_var ** 0.5 / self.z_var ** 0.5
rho = torch.clip(
rho, min=-((1 - self.minstd**2) ** 0.5), max=(1 - self.minstd**2) ** 0.5
rho, min=-((1 - self.minstd ** 2) ** 0.5), max=(1 - self.minstd ** 2) ** 0.5
)
logratios = (
-0.5 * torch.log(1 - rho**2)
+ rho / (1 - rho**2) * xb * zb
- 0.5 * rho**2 / (1 - rho**2) * (xb**2 + zb**2)
-0.5 * torch.log(1 - rho ** 2)
+ rho / (1 - rho ** 2) * xb * zb
- 0.5 * rho ** 2 / (1 - rho ** 2) * (xb ** 2 + zb ** 2)
)
out = LogRatioSamples(
logratios, z.unsqueeze(-1), self.varnames, metadata={"type": "Gaussian1d"}
)
return out

def get_z_estimate(self, x):
z_estimator = (x - self.x_mean) * self.xz_cov / self.x_var**0.5 + self.z_mean
z_estimator = (x - self.x_mean) * self.xz_cov / self.x_var ** 0.5 + self.z_mean
return z_estimator


Expand Down Expand Up @@ -448,7 +448,7 @@ def _get_mean_cov(x, correction=1):
def cov(self):
return (
self._cov
+ torch.eye(self._mean.shape[-1]).to(self._cov.device) * self._minstd**2
+ torch.eye(self._mean.shape[-1]).to(self._cov.device) * self._minstd ** 2
)

@property
Expand Down
2 changes: 1 addition & 1 deletion swyft/lightning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def _weighted_smoothed_histogramdd(v, w, bins=50, smooth=0):
low = v.min(axis=0).values
upp = v.max(axis=0).values
h = torchist.histogramdd(v, bins=bins, weights=w, low=low, upp=upp)
h /= len(v) * (upp[0] - low[0]) * (upp[1] - low[1]) / bins**2
h /= len(v) * (upp[0] - low[0]) * (upp[1] - low[1]) / bins ** 2
x = torch.linspace(low[0], upp[0], bins + 1)
y = torch.linspace(low[1], upp[1], bins + 1)
x = (x[1:] + x[:-1]) / 2
Expand Down
6 changes: 1 addition & 5 deletions swyft/networks/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,7 @@ def get_marginal_classifier(
num_blocks=num_blocks,
)

return Network(
observation_transform,
parameter_transform,
marginal_classifier,
)
return Network(observation_transform, parameter_transform, marginal_classifier,)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion swyft/networks/standardization.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _parallel_algorithm(
m2b = (
x.var(dim=(0,), unbiased=False) * nb
) # do not use bessel's correction then multiply by total number of items in batch.
m2ab = m2a + m2b + delta**2 * na * nb / nab
m2ab = m2a + m2b + delta ** 2 * na * nb / nab
return nab, xab, m2ab

def forward(self, x: torch.Tensor) -> torch.Tensor:
Expand Down
8 changes: 2 additions & 6 deletions swyft/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,7 @@ def _plot_2d(
ax.axhline(truth[parname2], color="k", lw=1.0, zorder=10, ls=(1, (5, 1)))
if parname1 in truth.keys() and parname2 in truth.keys():
ax.scatter(
[truth[parname1]],
[truth[parname2]],
c="k",
marker=".",
s=100,
[truth[parname1]], [truth[parname2]], c="k", marker=".", s=100,
)


Expand Down Expand Up @@ -557,7 +553,7 @@ def plot_pair(
smooth=smooth,
cred_level=cred_level,
truth=truth,
smooth_prior=smooth_prior
smooth_prior=smooth_prior,
)
ax.set_xlabel(labels[k][0], **label_args)
ax.set_ylabel(labels[k][1], **label_args)
Expand Down

0 comments on commit 393ed6c

Please sign in to comment.