-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathutils.py
76 lines (63 loc) · 2.52 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import os
import json
def deprocess(img):
    """Rescale values linearly so that -1 maps to 0 and 1 maps to 255.

    Args:
        img: A number or array-like supporting ``*`` and ``+`` with floats.

    Returns:
        ``img * 127.5 + 127.5``, in whatever type the arithmetic yields.
    """
    half_range = 127.5
    return img * half_range + half_range
def plot_batch(tensor, plot_shape, filename, img_size=32):
    """Save a batch of images as a borderless grid figure.

    Args:
        tensor: Batch of images as a torch tensor; it is permuted with
            ``(0, 2, 3, 1)``, i.e. treated as (N, C, H, W). Values are
            clipped to [0, 255] and cast to uint8, so they should already be
            in pixel range. The tensor may live on any device and may require
            grad; it is detached and moved to the CPU before conversion.
        plot_shape: Sequence whose first two items are ``(rows, columns)`` of
            the grid.
        filename: Destination passed to ``Figure.savefig``.
        img_size: Image side length used to scale the figure size
            proportionally.

    Images are laid out row by row. Grid cells beyond the number of images
    in the batch are left blank; images beyond ``rows * columns`` are ignored.
    """
    images = tensor.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = np.clip(images, 0, 255).astype(np.uint8)

    rows, columns = plot_shape[0], plot_shape[1]

    # Proportional scale by img size (figsize is expressed in inches).
    scale = img_size / 90

    # squeeze=False guarantees a 2-D axes array even for a single row or
    # column, so the grid can always be walked uniformly.
    fig, axes = plt.subplots(
        rows,
        columns,
        figsize=(columns * scale, rows * scale),
        squeeze=False,
    )
    try:
        num_images = images.shape[0]
        # ``axes.flat`` walks the grid in row-major order, matching the
        # image index ``row * columns + column``.
        for img_idx, ax in enumerate(axes.flat):
            if img_idx < num_images:
                ax.imshow(images[img_idx])
            # Hide ticks and frame on every cell, including unused ones.
            ax.axis('off')

        # Adjust the layout to minimize whitespace
        fig.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
        fig.savefig(filename)
    finally:
        # Always release this figure, even if drawing or saving failed.
        plt.close(fig)
class Config(object):
    """Attribute-style configuration persisted as ``config.json``.

    If ``<save_dir>/config.json`` already exists, its contents are loaded
    and ``input_dict`` is ignored. Otherwise the entries of ``input_dict``
    become attributes of the instance and are written to that file.
    Either way, the resulting attributes are printed.
    """

    def __init__(self, input_dict, save_dir):
        """Load the saved configuration or create one from ``input_dict``.

        Args:
            input_dict: Mapping of attribute names to values; used only when
                no saved configuration exists. Values must be
                JSON-serializable.
            save_dir: Directory that holds (or will hold) ``config.json``.

        Raises:
            TypeError: If a new configuration contains a value that cannot
                be serialized to JSON.
        """
        file_path = os.path.join(save_dir, "config.json")

        # Check if the configuration file exists
        if os.path.exists(file_path):
            self.load_config(file_path)
        else:
            for key, value in input_dict.items():
                setattr(self, key, value)
            self.save_config(file_path, save_dir)

        self.print_variables()

    def save_config(self, file_path, save_dir):
        """Write the instance attributes to ``file_path`` as indented JSON.

        Args:
            file_path: Destination of the JSON file.
            save_dir: Directory created (if missing) before writing.

        Raises:
            TypeError: If an attribute value is not JSON-serializable. In
                that case nothing is written to disk.
        """
        # Serialize first: if this fails, no truncated config.json is left
        # behind to be picked up (and choke the loader) on the next run.
        serialized = json.dumps(vars(self), indent=4)

        # Create the directory if it doesn't exist
        os.makedirs(save_dir, exist_ok=True)

        with open(file_path, "w", encoding="utf-8") as f:
            f.write(serialized)
        print(f'New config {file_path} saved')

    def load_config(self, file_path):
        """Set instance attributes from the JSON object stored in ``file_path``.

        Args:
            file_path: Path of an existing JSON configuration file.
        """
        with open(file_path, "r", encoding="utf-8") as f:
            config_data = json.load(f)

        # Update the object's attributes with loaded configuration
        for key, value in config_data.items():
            setattr(self, key, value)
        print(f'Config {file_path} loaded')

    def print_variables(self):
        """Print every instance attribute as ``key: value``, one per line."""
        for key, value in vars(self).items():
            print(f"{key}: {value}")