Commit

wandb but how to sweeps?

Exr0n committed Jun 22, 2021
1 parent 008b6db commit e006bfc

Showing 3 changed files with 75 additions and 22 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,4 +1,6 @@
**/data/
**/__pycache__/
.venv/
**/logs/
**/wandb/

@@ -16,20 +16,13 @@
from torch.nn import functional as F
from tqdm import tqdm

import wandb

from torch.utils.tensorboard import SummaryWriter

import csv, pickle
from pathlib import Path

def grouper(n, iterable):
    bad = []
    for i in iterable:
        bad.append(i)
        if len(bad) == n:
            yield torch.stack(bad)
            bad = []
    yield torch.stack(bad)

def plot_grad_flow_bars(named_parameters):
    # from @jemoka inscriptio gc
    ave_grads = []
@@ -57,6 +50,15 @@ def plot_grad_flow_bars(named_parameters):
    plt.show()

def load_data():
    def grouper(n, iterable):
        bad = []
        for i in iterable:
            bad.append(i)
            if len(bad) == n:
                yield torch.stack(bad)
                bad = []
        yield torch.stack(bad)

    if Path(DATAPATH+'.pkl').exists():
        print('found cached data; loading it...')
        with open(DATAPATH+'.pkl', 'rb') as rf:
@@ -91,11 +93,11 @@ def load_data():
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5) # -> 6 x 44x44
        self.pool = nn.MaxPool2d(2, 2) # -> 6 x 22x22; default stride = kernel_size
        self.conv2 = nn.Conv2d(6, 16, 5) # -> 16 x 18x18
        # pool again # -> 16 x 9 x 9
        self.full1 = nn.Linear(16 * 9*9, 200)
        self.conv1 = nn.Conv2d(1, 10, 5) # -> 10 x 44x44
        self.pool = nn.MaxPool2d(2, 2) # -> 10 x 22x22; default stride = kernel_size
        self.conv2 = nn.Conv2d(10, 20, 5) # -> 20 x 18x18
        # pool again # -> 20 x 9 x 9
        self.full1 = nn.Linear(20 * 9*9, 200)
        self.full2 = nn.Linear(200, 70)
        self.full3 = nn.Linear(70, 7)
        self.final = nn.Softmax(dim=1)
@@ -110,25 +112,31 @@ def forward(self, x):
        return self.final(x)

if __name__ == '__main__':
    wandb.init(project='facial expression recognition')
    print(f'pytorch version is {torch.__version__}')

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(f'training on device {device}')

    data = list(load_data())

    net = Net()
    print(net)
    print(f'parameter count: {len(list(net.parameters()))}')
    net.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)


    if SHOULD_LOG:
        writer = SummaryWriter(LOGS_DIR) # TODO: with statement
        wandb.watch(net)
        # writer = SummaryWriter(LOGS_DIR) # TODO: with statement

    with tqdm(total=EPOCHS*len(data)) as pbar:
        for epoch in range(EPOCHS):
            running_loss = 0. # TODO: nanny
            for i, samp in enumerate(data):
                img, cls = samp
                img, cls = samp[0].to(device), samp[1].to(device)
                # onehot = F.one_hot(cls, len(DATA_CLASSES)).type(torch.float32)

                optimizer.zero_grad()
@@ -140,12 +148,14 @@ def forward(self, x):
                # plot_grad_flow_bars(net.named_parameters())
                optimizer.step()

                running_loss += loss.item()
                pbar.update(1)
                pbar.set_description(f'step {epoch*len(data)+i}; loss {loss.item():.3f}')
                if SHOULD_LOG:
                    writer.add_scalar('loss', loss.item(), epoch*len(data)+i)
                if (epoch*len(data)+i) % int(1e2) == 0:
                    pbar.set_description(f'step {epoch*len(data)+i}; loss {loss.item():.3f}')
                    if SHOULD_LOG:
                        wandb.log({'loss': loss})
                        # writer.add_scalar('loss', loss.item(), epoch*len(data)+i)

    if SHOULD_LOG:
        writer.close()
        pass
        # writer.close()

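Aside on the commit title's question: the usual wandb pattern for sweeps is to describe the hyperparameters in a sweep config and let an agent call a train() function that reads them from wandb.config. The sketch below is not part of this commit; the parameter names, ranges, and trial count are illustrative assumptions, and it presumes the training loop above is refactored into train().

import wandb

# hypothetical sweep over this script's hyperparameters; 'loss' matches the
# metric already logged via wandb.log in the training loop
sweep_config = {
    'method': 'random',                               # or 'grid' / 'bayes'
    'metric': {'name': 'loss', 'goal': 'minimize'},
    'parameters': {
        'learning_rate': {'min': 1e-4, 'max': 1e-2},  # would stand in for LEARNING_RATE
        'epochs':        {'values': [5, 10, 20]},     # would stand in for EPOCHS
    },
}

def train():
    # one sweep trial: wandb.init() receives the values chosen by the sweep
    # controller and exposes them on wandb.config
    wandb.init(project='facial expression recognition')
    lr = wandb.config.learning_rate
    epochs = wandb.config.epochs
    # ... build Net(), optim.Adam(net.parameters(), lr=lr), and run the
    # existing loop for `epochs` epochs, still calling wandb.log({'loss': loss})

sweep_id = wandb.sweep(sweep_config, project='facial expression recognition')
wandb.agent(sweep_id, function=train, count=10)       # run 10 trials

The same sweep can also be launched from the shell with the wandb sweep / wandb agent CLI commands.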
41 changes: 41 additions & 0 deletions requirements.txt
@@ -0,0 +1,41 @@
absl-py==0.13.0
cachetools==4.2.2
certifi==2021.5.30
chardet==4.0.0
cycler==0.10.0
google-auth==1.32.0
google-auth-oauthlib==0.4.4
grpcio==1.38.0
idna==2.10
kiwisolver==1.3.1
llvmlite==0.36.0
Markdown==3.3.4
matplotlib==3.4.2
numba==0.53.1
numpy==1.20.3
oauthlib==3.1.1
pandas==1.2.4
Pillow==8.2.0
protobuf==3.17.3
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyparsing==2.4.7
PyQt5==5.15.4
PyQt5-Qt5==5.15.2
PyQt5-sip==12.9.0
python-dateutil==2.8.1
pytz==2021.1
requests==2.25.1
requests-oauthlib==1.3.0
rsa==4.7.2
six==1.16.0
tensorboard==2.5.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
torch==1.9.0+cu111
torchaudio==0.9.0
torchvision==0.10.0+cu111
tqdm==4.61.1
typing-extensions==3.10.0.0
urllib3==1.26.5
Werkzeug==2.0.1

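Note on installing these pins: the +cu111 builds of torch and torchvision are not hosted on PyPI, so installation likely also needs PyTorch's wheel index, e.g. (URL as documented by PyTorch around this release):

pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html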