update transform to return tensors; update train acc tracking
nkern committed Feb 27, 2024
1 parent 5983661 commit 3d7e57d
Showing 2 changed files with 31 additions and 17 deletions.
13 changes: 7 additions & 6 deletions py21cmnet/dataset.py
@@ -5,6 +5,7 @@
import numpy as np
from skimage import transform
from torch.utils.data import Dataset, DataLoader
import torch
from copy import deepcopy

from . import utils
@@ -28,15 +29,15 @@ def __call__(self, box, shift=None):
# compute shift if not fed
if shift is None:
if self.shift is None:
shift = np.random.randint(0, box[0].shape[-1], self.ndim)
shift = tuple(torch.randint(0, box[0].shape[-1], (self.ndim,)))
else:
shift = self.shift
if isinstance(box, (list, tuple)):
return [self.__call__(b, shift=shift) for b in box]
if self.ndim == 2:
return np.roll(box, shift, axis=(-1, -2))
return torch.roll(box, shift, dims=(-1, -2))
elif self.ndim == 3:
return np.roll(box, shift, axis=(-1, -2, -3))
return torch.roll(box, shift, dims=(-1, -2, -3))


class DownSample:
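
For context, a minimal sketch of the tensor-based rolling that the updated __call__ above performs, written against plain torch calls rather than the Roll class itself (the Roll constructor is not shown in this diff, so nothing here instantiates it):

import torch

# stand-in for a list of coeval boxes fed to the transform
ndim = 3
box = [torch.arange(4**3, dtype=torch.float32).reshape(4, 4, 4)]

# draw one random integer shift per axis, mirroring the new __call__
shift = tuple(int(s) for s in torch.randint(0, box[0].shape[-1], (ndim,)))

# torch.roll keeps the output a tensor (np.roll would hand back an ndarray)
rolled = [torch.roll(b, shift, dims=(-1, -2, -3)) for b in box]
print(rolled[0].shape)  # torch.Size([4, 4, 4])
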
@@ -119,7 +120,7 @@ def __call__(self, box, axes=None):
return [self.__call__(b, axes=axes) for b in box]
# modify axes for full_dim
axes = tuple(range(dim_diff)) + tuple(np.array(axes) + dim_diff)
return np.transpose(box, axes)
return torch.permute(box, axes)


class BoxDataset(Dataset):
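
One caveat on the transpose change above: np.transpose accepts a tuple of axes, while torch.transpose only swaps two dimensions, so the tuple-of-axes form maps to torch.permute. A standalone comparison, independent of the transform class:

import numpy as np
import torch

a = np.arange(24).reshape(2, 3, 4)
t = torch.arange(24).reshape(2, 3, 4)
axes = (2, 0, 1)

print(np.transpose(a, axes).shape)     # (4, 2, 3): reorders all axes at once
print(torch.permute(t, axes).shape)    # torch.Size([4, 2, 3]): tensor equivalent
print(torch.transpose(t, 0, 2).shape)  # torch.Size([4, 3, 2]): swaps exactly two dims
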
@@ -301,9 +302,9 @@ def __call__(self, box, undo=False):
if isinstance(box, (list, tuple)):
return [self.__call__(b, undo=undo) for b in box]
if not undo:
log = np.log10 if self.log10 else np.log
log = torch.log10 if self.log10 else torch.log
return log((box - self.offset) / self.scale)
else:
func = (lambda x: 10**x) if self.log10 else np.exp
func = (lambda x: 10**x) if self.log10 else torch.exp
return func(box) * self.scale + self.offset
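
The log-scaling __call__ above now runs entirely on tensors. A quick round-trip of the same arithmetic, with scalar offset/scale values assumed purely for illustration:

import torch

offset, scale, log10 = 0.0, 1.0, True
x = torch.rand(4, 4, 4) + 1.0                 # strictly positive input box

log = torch.log10 if log10 else torch.log
y = log((x - offset) / scale)                 # forward pass (undo=False)

func = (lambda v: 10**v) if log10 else torch.exp
x_back = func(y) * scale + offset             # inverse pass (undo=True)
print(torch.allclose(x, x_back))              # True
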

35 changes: 24 additions & 11 deletions py21cmnet/utils.py
@@ -13,7 +13,7 @@
from . import functional


def train(model, train_dloader, loss_fn, optim, optim_kwargs={},
def train(model, train_dloader, loss_fn, optim, optim_kwargs={}, track_mini=True,
acc_fn=None, Nepochs=1, valid_dloader=None, cuda=False, verbose=True):
"""
Model training function
@@ -31,8 +31,11 @@ def train(model, train_dloader, loss_fn, optim, optim_kwargs={},
Optimizer function
optim_kwargs : dict, default = {}
Optimizer function keyword arguments
track_mini : bool, default = True
If True, append the loss and accuracy every mini-batch, otherwise
only track them once per epoch.
acc_fn : callable, default = None
Accuracy function
Accuracy function, called as acc_fn(pred_labels, true_labels)
Nepochs : int, default = 1
Number of training epochs
valid_dloader : DataLoader object, default = None
@@ -58,7 +61,7 @@ def train(model, train_dloader, loss_fn, optim, optim_kwargs={},
# iterate over epochs
for epoch in range(Nepochs):
if verbose:
print('Epoch {}/{}'.format(epoch, Nepochs))
print('Epoch {}/{}'.format(epoch+1, Nepochs))
print('-' * 10)

# training and validation
@@ -77,7 +80,6 @@ def train(model, train_dloader, loss_fn, optim, optim_kwargs={},

running_loss = 0.0
running_acc = 0.0
step = 1 # this should start at 1, not 0
optimizer.zero_grad()

# iterate over data
@@ -94,14 +96,15 @@ def train(model, train_dloader, loss_fn, optim, optim_kwargs={},

# backprop
loss.backward()
loss = loss.detach()

# step
optimizer.step()
optimizer.zero_grad()
else:
with torch.no_grad():
# compute model and loss
out = model(x)
out = model(X)
loss = loss_fn(out, y)

# compute accuracy
@@ -114,12 +117,19 @@ def train(model, train_dloader, loss_fn, optim, optim_kwargs={},
running_loss += loss * X.shape[0]

if i % 10 == 0 and verbose:
print('Current step: {} Loss: {}'.format(i, loss))
print('Current step: {} Loss: {}'.format(i, loss.cpu()))
if cuda:
print("AllocMem (Mb) {}".format(torch.cuda.memory_allocated()/1024/1024))
print(torch.cuda.memory_summary())

step += 1
if track_mini:
if phase == 'train':
train_loss.append(loss.cpu())
train_acc.append(acc.cpu())
else:
valid_loss.append(loss.cpu())
valid_acc.append(acc.cpu())


epoch_loss = running_loss / len(dataloader.dataset)
epoch_acc = running_acc / len(dataloader.dataset)
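
Condensed into one runnable snippet, the bookkeeping this hunk introduces looks roughly like the following; it is a paraphrase of the training-phase path only, with made-up batch values standing in for the model, loss_fn, and acc_fn outputs:

import torch

track_mini = True
train_loss, train_acc = [], []
running_loss, running_acc = 0.0, 0.0
n_seen = 0

for i in range(5):                        # stand-in for the mini-batch loop
    X = torch.rand(8, 1, 16, 16)          # hypothetical batch
    loss = torch.rand(())                 # stand-in for loss_fn(out, y)
    acc = torch.rand(())                  # stand-in for acc_fn(out, y)

    running_loss += loss * X.shape[0]     # weight running totals by batch size
    running_acc += acc * X.shape[0]
    n_seen += X.shape[0]

    if track_mini:                        # per-mini-batch history
        train_loss.append(loss.cpu())
        train_acc.append(acc.cpu())

epoch_loss = running_loss / n_seen        # per-epoch averages
epoch_acc = running_acc / n_seen
if not track_mini:                        # otherwise one entry per epoch
    train_loss.append(epoch_loss.cpu())
    train_acc.append(epoch_acc.cpu())
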
@@ -128,10 +138,13 @@ def train(model, train_dloader, loss_fn, optim, optim_kwargs={},
print('{} Loss: {:.4f} Acc: {}'.format(phase, epoch_loss, epoch_acc))
print('-' * 10)

if phase == 'train':
train_loss.append(epoch_loss)
else:
valid_loss.append(epoch_loss)
if not track_mini:
if phase == 'train':
train_loss.append(epoch_loss.cpu())
train_acc.append(epoch_acc.cpu())
else:
valid_loss.append(epoch_loss.cpu())
valid_acc.append(epoch_acc.cpu())

time_elapsed = time.time() - start
if verbose:
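
Pulling the new arguments together, a hedged usage sketch: the model, data, and accuracy function below are toy placeholders, and it assumes optim is the optimizer class that train instantiates internally with optim_kwargs (that instantiation, and train's return value, sit outside the visible hunks):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from py21cmnet.utils import train

# toy stand-ins for a real 21cm model and box dataset
model = nn.Conv3d(1, 2, kernel_size=3, padding=1)
X = torch.rand(16, 1, 8, 8, 8)
y = torch.randint(0, 2, (16, 8, 8, 8))
train_dloader = DataLoader(TensorDataset(X, y), batch_size=4)

def acc_fn(pred_labels, true_labels):
    # placeholder accuracy matching the acc_fn(pred_labels, true_labels) signature;
    # the exact form of pred_labels is not visible in this diff
    pred = pred_labels.argmax(1) if pred_labels.dim() > true_labels.dim() else pred_labels
    return (pred == true_labels).float().mean()

# track_mini=True records loss/accuracy after every mini-batch
train(model, train_dloader, nn.CrossEntropyLoss(), torch.optim.Adam,
      optim_kwargs={'lr': 1e-3}, track_mini=True,
      acc_fn=acc_fn, Nepochs=2, cuda=False, verbose=True)
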
