From 3c1652a891376249095353440fb39f17c539df25 Mon Sep 17 00:00:00 2001 From: Nicholas Kern Date: Wed, 6 Mar 2024 14:50:48 -0500 Subject: [PATCH] some minor bug fixes from changes in augmentations --- py21cmnet/dataset.py | 7 ++--- py21cmnet/models.py | 10 ++++--- py21cmnet/tests/test_dataset.py | 48 +++++++++++++++++---------------- py21cmnet/tests/test_models.py | 7 +++-- py21cmnet/utils.py | 6 ++++- 5 files changed, 46 insertions(+), 32 deletions(-) diff --git a/py21cmnet/dataset.py b/py21cmnet/dataset.py index 24d69e6..ccf5ce3 100644 --- a/py21cmnet/dataset.py +++ b/py21cmnet/dataset.py @@ -17,8 +17,9 @@ def __init__(self, shift=None, ndim=3): """roll a periodic box by "shift" pixels Args: shift : int or tuple - If roll a box by shift pixels along - the last ndim axes. Default is random. + Roll the box by this many pixels + along each of the specified dimensions. + Default is a random number per dimension. ndim : int Dimensionality of the box """ @@ -120,7 +121,7 @@ def __call__(self, box, axes=None): return [self.__call__(b, axes=axes) for b in box] # modify axes for full_dim axes = tuple(range(dim_diff)) + tuple(np.array(axes) + dim_diff) - return torch.transpose(box, axes) + return torch.permute(box, axes) class BoxDataset(Dataset): diff --git a/py21cmnet/models.py b/py21cmnet/models.py index bfffee7..6a15d0a 100644 --- a/py21cmnet/models.py +++ b/py21cmnet/models.py @@ -241,9 +241,13 @@ def crop_concat(self, X, connection): X = torch.cat([self.center_crop(connection, X.shape[-Nd:]), X], dim=1) return X - def forward(self, X, connection=None): + def forward(self, X, connection=None, metadata=None): if connection is not None: X = self.crop_concat(X, connection) + if metadata is not None: + shape = X.shape + shape[1] = len(metadata) + X = torch.cat([X, metadata.expand(shape)], dim=1) out = self.model(self.pass_to_device(X)) return out @@ -301,7 +305,7 @@ def __init__(self, encoder_layers, decoder_layers, self.final_transforms = final_transforms - def forward(self, X, debug=False): + def forward(self, X, debug=False, metadata=None): # pass through encoder connects = [] for i, encode in enumerate(self.encoder): @@ -320,7 +324,7 @@ def forward(self, X, debug=False): connection = connects[self.connections[i]] else: connection = None - X = decode(X, connection) + X = decode(X, connection, metadata=metadata if i == 0 else None) if debug: print("finished decoder block {}".format(i)) # final transformations diff --git a/py21cmnet/tests/test_dataset.py b/py21cmnet/tests/test_dataset.py index dae512e..cac7a73 100644 --- a/py21cmnet/tests/test_dataset.py +++ b/py21cmnet/tests/test_dataset.py @@ -19,8 +19,10 @@ def test_transforms(): db = utils.load_hdf5(fname + '/deltax', dtype=np.float32) box = utils.load_hdf5([fname + '/deltax', fname + '/Ts'], dtype=np.float32) + db, box = torch.as_tensor(db), torch.as_tensor(box) + # roll the cube - Roll = dataset.Roll(50, ndim=3) + Roll = dataset.Roll((50, 50, 50), ndim=3) assert Roll(db).shape == db.shape assert not (Roll(db) == db).any() @@ -30,9 +32,9 @@ def test_transforms(): # downsample DS = dataset.DownSample(2, ndim=3) - assert DS(db).shape == tuple(np.array(db.shape)/2) + assert DS(db).shape == torch.Size(np.array(db.shape)//2) assert (DS(db) == db[::2, ::2, ::2]).all() - assert DS(box).shape == box.shape[:1] + tuple(np.array(box.shape[1:])/2) + assert DS(box).shape == box.shape[:1] + torch.Size(np.array(box.shape[1:])//2) # transpose db_mod = db[:, ::2, ::4] @@ -76,9 +78,9 @@ def test_dataset(): dtype = np.float32 # simple load - X = utils.load_hdf5(Xfiles[0], dtype=dtype) - y = utils.load_hdf5(yfiles[0], dtype=dtype) - dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype) + X = utils.load_hdf5_torch(Xfiles[0], dtype=dtype) + y = utils.load_hdf5_torch(yfiles[0], dtype=dtype) + dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype) assert len(dl) == 1 assert (dl[0][0] == X).all() assert (dl[0][1] == y).all() @@ -90,10 +92,10 @@ def test_dataset(): assert (dl[0][1] == y).all() # load with transformation - trans = Compose([dataset.Roll(shift=20, ndim=3), dataset.DownSample(thin=2, ndim=3)]) - dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype, transform=trans) + trans = Compose([dataset.Roll(shift=(20,20,20), ndim=3), dataset.DownSample(thin=2, ndim=3)]) + dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype, transform=trans) assert len(dl) == 1 - assert dl[0][0].shape == X.shape[:1] + tuple(np.array(X.shape[1:])/2) + assert dl[0][0].shape == X.shape[:1] + torch.Size(np.array(X.shape[1:])//2) assert not (dl[0][0][0] == X[0, ::2, ::2, ::2]).any() def test_augmentations(): @@ -101,8 +103,8 @@ def test_augmentations(): Xfiles = [[fname+'/deltax', fname+'/Gamma']] yfiles = [[fname+'/x_HI', fname+'/Ts']] dtype = np.float32 - X = utils.load_hdf5(Xfiles[0], dtype=dtype) - y = utils.load_hdf5(yfiles[0], dtype=dtype) + X = utils.load_hdf5_torch(Xfiles[0], dtype=dtype) + y = utils.load_hdf5_torch(yfiles[0], dtype=dtype) # test single augmentation aug = dataset.Logarithm(offset=-1) @@ -124,8 +126,8 @@ def shift(x, undo=False): # try with dataset: only one augmentation for both X and y Xaugment, yaugment = dataset.Logarithm(offset=-1), dataset.Logarithm(offset=-1) - dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype) - dl_aug = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype, + dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype) + dl_aug = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype, X_augment=Xaugment, y_augment=yaugment) X, y = dl[0] Xaug, yaug = dl_aug[0] @@ -139,8 +141,8 @@ def shift(x, undo=False): # try with some augmentation for X and y channels Xaugment = [dataset.Logarithm(offset=-1), None] yaugment = [None, dataset.Logarithm()] - dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype) - dl_aug = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5, dtype=dtype, + dl = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype) + dl_aug = dataset.BoxDataset(Xfiles, yfiles, utils.load_hdf5_torch, dtype=dtype, X_augment=Xaugment, y_augment=yaugment) X, y = dl[0] Xaug, yaug = dl_aug[0] @@ -154,8 +156,8 @@ def shift(x, undo=False): assert np.isclose(y, dl_aug.augment(Xaug, yaug, undo=True)[1], atol=1e-6).all() # try with no aug for y and check memory location - _X = utils.load_hdf5(Xfiles[0], dtype=dtype) - _y = utils.load_hdf5(yfiles[0], dtype=dtype) + _X = utils.load_hdf5_torch(Xfiles[0], dtype=dtype) + _y = utils.load_hdf5_torch(yfiles[0], dtype=dtype) Xaugment, yaugment = dataset.Logarithm(offset=-1), None dl_aug = dataset.BoxDataset([_X], [_y], utils.load_dummy, X_augment=Xaugment, y_augment=yaugment) @@ -165,19 +167,19 @@ def shift(x, undo=False): assert hex(id(_y)) == hex(id(yaug)) # check inplace augmentation - _X = utils.load_hdf5(Xfiles[0], dtype=dtype) - _y = utils.load_hdf5(yfiles[0], dtype=dtype) + _X = utils.load_hdf5_torch(Xfiles[0], dtype=dtype) + _y = utils.load_hdf5_torch(yfiles[0], dtype=dtype) Xaugment, yaugment = dataset.Logarithm(offset=-1), dataset.Logarithm(offset=-1) dl_aug = dataset.BoxDataset([_X], [_y], utils.load_dummy, X_augment=Xaugment, y_augment=yaugment) Xaug, yaug = dl_aug.augment(_X, _y, inplace=True) # check that it did indeed augment - assert not np.isclose(_X, utils.load_hdf5(Xfiles[0], dtype=dtype), atol=1e-6).all() - assert not np.isclose(_y, utils.load_hdf5(yfiles[0], dtype=dtype), atol=1e-6).all() + assert not np.isclose(_X, utils.load_hdf5_torch(Xfiles[0], dtype=dtype), atol=1e-6).all() + assert not np.isclose(_y, utils.load_hdf5_torch(yfiles[0], dtype=dtype), atol=1e-6).all() # check memory address is the same assert hex(id(_X)) == hex(id(Xaug)) assert hex(id(_y)) == hex(id(yaug)) # check reverse inplace augmentation dl_aug.augment(_X, _y, undo=True, inplace=True) - assert np.isclose(_X, utils.load_hdf5(Xfiles[0], dtype=dtype), atol=1e-6).all() - assert np.isclose(_y, utils.load_hdf5(yfiles[0], dtype=dtype), atol=1e-6).all() + assert np.isclose(_X, utils.load_hdf5_torch(Xfiles[0], dtype=dtype), atol=1e-6).all() + assert np.isclose(_y, utils.load_hdf5_torch(yfiles[0], dtype=dtype), atol=1e-6).all() diff --git a/py21cmnet/tests/test_models.py b/py21cmnet/tests/test_models.py index 520024b..df971ad 100644 --- a/py21cmnet/tests/test_models.py +++ b/py21cmnet/tests/test_models.py @@ -97,8 +97,11 @@ def test_autoencoder(): dl = torch.utils.data.DataLoader(ds) info = utils.train(model, dl, torch.nn.MSELoss(reduction='mean'), torch.optim.Adam, optim_kwargs=dict(lr=0.1), Nepochs=3, verbose=True) - # assert loss decreases - assert (np.diff(torch.stack(info['train_loss']).detach().numpy()) < 0).all() + # assert average loss decreases + loss = torch.stack(info['train_loss']).cpu() + loss_start = torch.mean(loss[:len(loss)//2]) + loss_stop = torch.mean(loss[len(loss)//2:]) + assert loss_stop < loss_start # pred = model(X) # import matplotlib.pyplot as plt; diff --git a/py21cmnet/utils.py b/py21cmnet/utils.py index 0138d4c..181d4af 100644 --- a/py21cmnet/utils.py +++ b/py21cmnet/utils.py @@ -111,7 +111,7 @@ def train(model, train_dloader, loss_fn, optim, optim_kwargs={}, track_mini=True if acc_fn is not None: acc = acc_fn(out, y) else: - acc = 0 + acc = torch.tensor(0.) running_acc += acc * X.shape[0] running_loss += loss * X.shape[0] @@ -312,6 +312,10 @@ def load_hdf5(fname, dtype=None): return box +def load_hdf5_torch(*args, **kwargs): + return torch.as_tensor(load_hdf5(*args, **kwargs)) + + def _update_dict(d1, d2): for key in d2.keys(): if key in d1: