-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutility.py
59 lines (51 loc) · 1.63 KB
/
utility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import matplotlib.pyplot as plt
import argparse
import torch
from torch import nn
from torch import Tensor
from torch.nn.modules.module import Module
import torch.nn.functional as F
from torchvision import transforms as T
from torch.nn import Dropout, Softmax, Linear, Conv3d, LayerNorm, Flatten, Conv2d
import torchvision
from torchvision import transforms, models
import math
import copy
import torchsummary
from torch import autograd
from torch.autograd import Variable
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchsummary import summary
import GPUtil
import nibabel as nib
import numpy as np
from tqdm import tqdm, trange
from itertools import cycle
import pandas as pd
from sklearn.model_selection import train_test_split
import cv2
from scipy.ndimage.filters import gaussian_filter
from scipy.stats import norm
from math import exp, sqrt
from PIL import Image
import wandb
from torchvision.utils import save_image
import torchvision.transforms.functional as FF
import lpips
# Command-line options for the scripts that import this module. They are
# parsed eagerly, at import time, and exposed as the module-level `args`.
parser = argparse.ArgumentParser()
for _flag, _kind, _default in (
    ("--gpu_id", str, "3"),            # kept as a string (type=str)
    ("--batch_size", int, 64),
    ("--epoch", int, 101),
    ("--lr_age", float, 0.0005),       # the three learning rates, as named
    ("--lr_gan", float, 0.0005),
    ("--lr_map", float, 0.00001),
    ("--id_optim", str, "Adam"),       # optimizer name; consumer not shown here
):
    parser.add_argument(_flag, type=_kind, default=_default)
# Drop the loop temporaries so the module namespace matches a plain
# sequence of add_argument calls.
del _flag, _kind, _default
args = parser.parse_args()
def age_onehot(age, num_classes=101):
    """Return a one-hot encoding of integer ages.

    Args:
        age: Integer ages as a tensor (or anything ``torch.as_tensor``
            accepts). Any shape is allowed, including a 0-d scalar. Each
            element becomes one output row, in row-major order.
        num_classes: Width of the encoding. The default of 101 covers
            ages 0..100, matching the original hard-coded width.

    Returns:
        A tensor of shape ``(age.numel(), num_classes)`` in the default
        floating dtype, on the same device as ``age``. Each row holds a
        single 1.

    Raises:
        IndexError: If any age is negative or ``>= num_classes``.
    """
    # Flatten so a scalar becomes shape (1,) and a column vector (B, 1)
    # becomes (B,). Indexing with an un-flattened (B, 1) tensor would
    # broadcast against arange(B) and set B*B entries instead of B.
    age = torch.as_tensor(age).reshape(-1)
    n = age.numel()

    # Check the range explicitly. Negative indices would otherwise silently
    # wrap around to the last columns. Skip the check for empty input,
    # because min()/max() raise on empty tensors.
    if n and (age.min().item() < 0 or age.max().item() >= num_classes):
        raise IndexError(
            f"age values must lie in [0, {num_classes - 1}], "
            f"got min={age.min().item()}, max={age.max().item()}"
        )

    # Allocate on age's device. Indexing a CPU tensor with CUDA indices
    # raises an error.
    z = torch.zeros(n, num_classes, device=age.device)
    z[torch.arange(n, device=age.device), age] = 1
    return z