-
Notifications
You must be signed in to change notification settings - Fork 165
/
utils.py
39 lines (37 loc) · 1.01 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import json
import argparse
import numpy as np
import torch.nn.functional as F
def save_args(args, to_path):
    """Write every attribute of ``args`` to ``to_path`` as indented JSON.

    Args:
        args: an object with a ``__dict__`` (typically an ``argparse.Namespace``);
            all of its attributes must be JSON-serializable.
        to_path: destination file path; it is created or overwritten.
    """
    with open(to_path, "w") as handle:
        attrs = args.__dict__
        json.dump(attrs, handle, indent=2)
def load_args(from_path, is_test=True):
    """Rebuild an ``argparse.Namespace`` from a JSON file.

    The namespace is constructed directly instead of via
    ``ArgumentParser().parse_args()``. An empty parser still parses
    ``sys.argv`` and exits with "unrecognized arguments" whenever the calling
    script was itself launched with command-line flags.

    Args:
        from_path: path to a JSON file whose top-level object maps attribute
            names to values (e.g. one written by ``save_args``).
        is_test: stored as ``args.is_test``; always overrides any ``is_test``
            value present in the file.

    Returns:
        argparse.Namespace holding the loaded attributes, with ``is_test`` set
        and ``E_name`` defaulting to ``"basic"`` when the file has no entry
        for it.
    """
    with open(from_path, "r", encoding="utf-8") as f:
        loaded = json.load(f)
    args = argparse.Namespace()
    vars(args).update(loaded)
    args.is_test = is_test
    # Files without an "E_name" entry get the default "basic".
    if "E_name" not in vars(args):
        args.E_name = "basic"
    return args
def tensor2img(x):
    """Convert an image tensor to a uint8 numpy image tiled vertically.

    Args:
        x: tensor shaped [BS x c x H x W] or [c x H x W]; a 3-D input is
            treated as a batch of one. Values are clipped to [-1, 1] and then
            mapped linearly onto [0, 255].

    Returns:
        numpy uint8 array of shape (BS*H, W, c), with the batch images
        stacked top to bottom. Single-channel input is replicated to
        3 channels.
    """
    batch = x.unsqueeze(0) if x.ndim == 3 else x
    _, channels, _, width = batch.shape
    # Channels-last, then fold the batch axis into the height axis.
    stacked = batch.permute(0, 2, 3, 1).reshape(-1, width, channels)
    arr = stacked.detach().cpu().numpy()
    arr = (np.clip(arr, -1, 1) + 1) / 2
    img = np.uint8(arr * 255.0)
    if img.shape[-1] == 1:  # grayscale -> 3 identical channels
        img = np.concatenate([img] * 3, axis=-1)
    return img
def resize_mask(m, shape, threshold=0.5, mode="nearest"):
    """Resize a mask tensor with ``F.interpolate`` and binarize it to {0, 1}.

    Args:
        m: floating-point tensor accepted by ``F.interpolate`` (presumably
            [BS x C x H x W] for 2-D masks; confirm against callers).
        shape: target spatial size, forwarded to ``F.interpolate`` as ``size``.
        threshold: elements ``>= threshold`` become 1 and all others
            (including NaN) become 0.
        mode: interpolation mode forwarded to ``F.interpolate``; the default
            "nearest" matches the original behavior.

    Returns:
        The resized tensor, with every element set to 0 or 1.
    """
    m = F.interpolate(m, size=shape, mode=mode)
    # Compute the mask once. A strict ">" / "<" pair leaves values exactly at
    # the threshold unbinarized, and two sequential in-place assignments would
    # re-zero freshly written 1s whenever threshold > 1.
    keep = m >= threshold
    m[keep] = 1
    m[~keep] = 0
    return m