Skip to content

Commit

Permalink
some minor bug fixes from changes in augmentations
Browse files Browse the repository at this point in the history
  • Loading branch information
nkern committed Mar 6, 2024
1 parent 3d7e57d commit 3c1652a
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 32 deletions.
7 changes: 4 additions & 3 deletions py21cmnet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ def __init__(self, shift=None, ndim=3):
"""roll a periodic box by "shift" pixels
Args:
shift : int or tuple
If roll a box by shift pixels along
the last ndim axes. Default is random.
Roll the box by this many pixels
along each of the specified dimensions.
Default is a random number per dimension.
ndim : int
Dimensionality of the box
"""
Expand Down Expand Up @@ -120,7 +121,7 @@ def __call__(self, box, axes=None):
return [self.__call__(b, axes=axes) for b in box]
# modify axes for full_dim
axes = tuple(range(dim_diff)) + tuple(np.array(axes) + dim_diff)
return torch.transpose(box, axes)
return torch.permute(box, axes)


class BoxDataset(Dataset):
Expand Down
10 changes: 7 additions & 3 deletions py21cmnet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,13 @@ def crop_concat(self, X, connection):
X = torch.cat([self.center_crop(connection, X.shape[-Nd:]), X], dim=1)
return X

def forward(self, X, connection=None):
def forward(self, X, connection=None, metadata=None):
if connection is not None:
X = self.crop_concat(X, connection)
if metadata is not None:
shape = X.shape
shape[1] = len(metadata)
X = torch.cat([X, metadata.expand(shape)], dim=1)
out = self.model(self.pass_to_device(X))
return out

Expand Down Expand Up @@ -301,7 +305,7 @@ def __init__(self, encoder_layers, decoder_layers,

self.final_transforms = final_transforms

def forward(self, X, debug=False):
def forward(self, X, debug=False, metadata=None):
# pass through encoder
connects = []
for i, encode in enumerate(self.encoder):
Expand All @@ -320,7 +324,7 @@ def forward(self, X, debug=False):
connection = connects[self.connections[i]]
else:
connection = None
X = decode(X, connection)
X = decode(X, connection, metadata=metadata if i == 0 else None)
if debug: print("finished decoder block {}".format(i))

# final transformations
Expand Down
48 changes: 25 additions & 23 deletions py21cmnet/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ def test_transforms():
db = utils.load_hdf5(fname + '/deltax', dtype=np.float32)
box = utils.load_hdf5([fname + '/deltax', fname + '/Ts'], dtype=np.float32)

db, box = torch.as_tensor(db), torch.as_tensor(box)

# roll the cube
Roll = dataset.Roll(50, ndim=3)
Roll = dataset.Roll((50, 50, 50), ndim=3)
assert Roll(db).shape == db.shape
assert not (Roll(db) == db).any()

Expand All @@ -30,9 +32,9 @@ def test_transforms():

# downsample
DS = dataset.DownSample(2, ndim=3)
assert DS(db).shape == tuple(np.array(db.shape)/2)
assert DS(db).shape == torch.Size(np.array(db.shape)//2)
assert (DS(db) == db[::2, ::2, ::2]).all()
assert DS(box).shape == box.shape[:1] + tuple(np.array(box.shape[1:])/2)
assert DS(box).shape == box.shape[:1] + torch.Size(np.array(box.shape[1:])//2)

# transpose
db_mod = db[:, ::2, ::4]
Expand Down Expand Up @@ -76,9 +78,9 @@ def test_dataset():
dtype = np.float32

# simple load
X = utils.load_hdf5(Xfiles[0], dtype=dtype)
y = utils.load_hdf5(yfiles[0], dtype=dtype)
dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype)
X = utils.load_hdf5_torch(Xfiles[0], dtype=dtype)
y = utils.load_hdf5_torch(yfiles[0], dtype=dtype)
dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype)
assert len(dl) == 1
assert (dl[0][0] == X).all()
assert (dl[0][1] == y).all()
Expand All @@ -90,19 +92,19 @@ def test_dataset():
assert (dl[0][1] == y).all()

# load with transformation
trans = Compose([dataset.Roll(shift=20, ndim=3), dataset.DownSample(thin=2, ndim=3)])
dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype, transform=trans)
trans = Compose([dataset.Roll(shift=(20,20,20), ndim=3), dataset.DownSample(thin=2, ndim=3)])
dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype, transform=trans)
assert len(dl) == 1
assert dl[0][0].shape == X.shape[:1] + tuple(np.array(X.shape[1:])/2)
assert dl[0][0].shape == X.shape[:1] + torch.Size(np.array(X.shape[1:])//2)
assert not (dl[0][0][0] == X[0, ::2, ::2, ::2]).any()

def test_augmentations():
fname = os.path.join(DATA_PATH, "train_21cmfast_basic.h5")
Xfiles = [[fname+'/deltax', fname+'/Gamma']]
yfiles = [[fname+'/x_HI', fname+'/Ts']]
dtype = np.float32
X = utils.load_hdf5(Xfiles[0], dtype=dtype)
y = utils.load_hdf5(yfiles[0], dtype=dtype)
X = utils.load_hdf5_torch(Xfiles[0], dtype=dtype)
y = utils.load_hdf5_torch(yfiles[0], dtype=dtype)

# test single augmentation
aug = dataset.Logarithm(offset=-1)
Expand All @@ -124,8 +126,8 @@ def shift(x, undo=False):

# try with dataset: only one augmentation for both X and y
Xaugment, yaugment = dataset.Logarithm(offset=-1), dataset.Logarithm(offset=-1)
dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype)
dl_aug = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype,
dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype)
dl_aug = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype,
X_augment=Xaugment, y_augment=yaugment)
X, y = dl[0]
Xaug, yaug = dl_aug[0]
Expand All @@ -139,8 +141,8 @@ def shift(x, undo=False):
# try with some augmentation for X and y channels
Xaugment = [dataset.Logarithm(offset=-1), None]
yaugment = [None, dataset.Logarithm()]
dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype)
dl_aug = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype,
dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype)
dl_aug = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype,
X_augment=Xaugment, y_augment=yaugment)
X, y = dl[0]
Xaug, yaug = dl_aug[0]
Expand All @@ -154,8 +156,8 @@ def shift(x, undo=False):
assert np.isclose(y, dl_aug.augment(Xaug, yaug, undo=True)[1], atol=1e-6).all()

# try with no aug for y and check memory location
_X = utils.load_hdf5(Xfiles[0], dtype=dtype)
_y = utils.load_hdf5(yfiles[0], dtype=dtype)
_X = utils.load_hdf5_torch(Xfiles[0], dtype=dtype)
_y = utils.load_hdf5_torch(yfiles[0], dtype=dtype)
Xaugment, yaugment = dataset.Logarithm(offset=-1), None
dl_aug = dataset.BoxDataset([_X], [_y], utils.load_dummy,
X_augment=Xaugment, y_augment=yaugment)
Expand All @@ -165,19 +167,19 @@ def shift(x, undo=False):
assert hex(id(_y)) == hex(id(yaug))

# check inplace augmentation
_X = utils.load_hdf5(Xfiles[0], dtype=dtype)
_y = utils.load_hdf5(yfiles[0], dtype=dtype)
_X = utils.load_hdf5_torch(Xfiles[0], dtype=dtype)
_y = utils.load_hdf5_torch(yfiles[0], dtype=dtype)
Xaugment, yaugment = dataset.Logarithm(offset=-1), dataset.Logarithm(offset=-1)
dl_aug = dataset.BoxDataset([_X], [_y], utils.load_dummy,
X_augment=Xaugment, y_augment=yaugment)
Xaug, yaug = dl_aug.augment(_X, _y, inplace=True)
# check that it did indeed augment
assert not np.isclose(_X, utils.load_hdf5(Xfiles[0], dtype=dtype), atol=1e-6).all()
assert not np.isclose(_y, utils.load_hdf5(yfiles[0], dtype=dtype), atol=1e-6).all()
assert not np.isclose(_X, utils.load_hdf5_torch(Xfiles[0], dtype=dtype), atol=1e-6).all()
assert not np.isclose(_y, utils.load_hdf5_torch(yfiles[0], dtype=dtype), atol=1e-6).all()
# check memory address is the same
assert hex(id(_X)) == hex(id(Xaug))
assert hex(id(_y)) == hex(id(yaug))
# check reverse inplace augmentation
dl_aug.augment(_X, _y, undo=True, inplace=True)
assert np.isclose(_X, utils.load_hdf5(Xfiles[0], dtype=dtype), atol=1e-6).all()
assert np.isclose(_y, utils.load_hdf5(yfiles[0], dtype=dtype), atol=1e-6).all()
assert np.isclose(_X, utils.load_hdf5_torch(Xfiles[0], dtype=dtype), atol=1e-6).all()
assert np.isclose(_y, utils.load_hdf5_torch(yfiles[0], dtype=dtype), atol=1e-6).all()
7 changes: 5 additions & 2 deletions py21cmnet/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ def test_autoencoder():
dl = torch.utils.data.DataLoader(ds)
info = utils.train(model, dl, torch.nn.MSELoss(reduction='mean'), torch.optim.Adam,
optim_kwargs=dict(lr=0.1), Nepochs=3, verbose=True)
# assert loss decreases
assert (np.diff(torch.stack(info['train_loss']).detach().numpy()) < 0).all()
# assert average loss decreases
loss = torch.stack(info['train_loss']).cpu()
loss_start = torch.mean(loss[:len(loss)//2])
loss_stop = torch.mean(loss[len(loss)//2:])
assert loss_stop < loss_start

# pred = model(X)
# import matplotlib.pyplot as plt;
Expand Down
6 changes: 5 additions & 1 deletion py21cmnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def train(model, train_dloader, loss_fn, optim, optim_kwargs={}, track_mini=True
if acc_fn is not None:
acc = acc_fn(out, y)
else:
acc = 0
acc = torch.tensor(0.)

running_acc += acc * X.shape[0]
running_loss += loss * X.shape[0]
Expand Down Expand Up @@ -312,6 +312,10 @@ def load_hdf5(fname, dtype=None):
return box


def load_hdf5_torch(*args, **kwargs):
return torch.as_tensor(load_hdf5(*args, **kwargs))


def _update_dict(d1, d2):
for key in d2.keys():
if key in d1:
Expand Down

0 comments on commit 3c1652a

Please sign in to comment.