-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path02.Augmentation.py
95 lines (79 loc) · 2.71 KB
/
02.Augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import argparse
import os
from torchvision.io import read_image
from lib.images import analyze, flip, rotate, perspective
from lib.images import brightness, contrast, saturation, IMAGE_FOLDER
from lib.print import warning
import random
def print_summary(before, after):
    """
    Print a before/after report of the counts per subdirectory.

    Every key of ``before`` is listed with its own value followed by the
    value stored under the same key in ``after``.
    """
    rule = "========================"
    # One call with a newline separator emits the three banner lines.
    print(rule, "Summary of augmentation:", rule, sep="\n")
    for name, count in before.items():
        print(f"{name}: {count} -> {after[name]}")
def random_transform(img, file_path):
    """
    Apply one randomly chosen augmentation to an image.

    One of flip, rotate, perspective, brightness, contrast or saturation is
    picked with equal probability and called with ``img`` and ``file_path``.
    Whatever the chosen transformation returns is discarded.
    """
    transforms = (flip, rotate, perspective, brightness, contrast, saturation)
    # random.choice draws over the actual length of the sequence, so adding
    # or removing a transformation cannot leave a stale index bound behind.
    random.choice(transforms)(img, file_path)
def augment(category: str):
    """
    Augment the images of one category until its subdirectories are balanced.

    The largest count reported by ``analyze(category)`` is the target. Every
    subdirectory of IMAGE_FOLDER whose name starts with ``category``
    (case-insensitive) gets randomly transformed copies of its files until
    its count reaches that target. A before/after summary is printed at the
    end.

    NOTE(review): this relies on ``analyze`` returning a mapping keyed by the
    absolute subdirectory path, with the number of images as value - confirm
    against lib.images.
    """
    print(f'\nAugmenting "{category}" images...')
    counts_before = analyze(category)
    if not counts_before:
        # max() of an empty mapping raises ValueError; there is nothing to
        # balance, so report it and stop.
        print(f'No images found for "{category}", nothing to augment.')
        return
    counts_after = counts_before.copy()
    max_num = max(counts_after.values())
    base = os.path.abspath(os.getcwd())
    # Only the immediate subdirectories of IMAGE_FOLDER are considered.
    root, dirs, _ = next(os.walk(IMAGE_FOLDER))
    prefix = category.lower()
    for dirname in dirs:
        if not dirname.lower().startswith(prefix):
            continue
        dir_path = os.path.join(base, root, dirname)
        while counts_after[dir_path] < max_num:
            entries = os.listdir(dir_path)
            if not entries:
                # Without any source file the count can never grow and the
                # loop would spin forever.
                print(f'No files in "{dir_path}", skipping augmentation.')
                break
            for file in entries:
                file_path = os.path.join(dir_path, file)
                img = read_image(file_path)
                random_transform(img, file_path)
                counts_after[dir_path] += 1
                if counts_after[dir_path] == max_num:
                    break
            # Reset the number of images with real number of files.
            counts_after = analyze(category)
    counts_after = analyze(category)
    print_summary(counts_before, counts_after)
def main(file_path):
    """
    Run the augmentation for a single image or for the whole dataset.

    With a truthy ``file_path`` the image is read once and all six
    transformations are applied to it. Otherwise the "Apple" and "Grape"
    categories are augmented automatically. Any exception raised on the way
    is caught and printed instead of being propagated.
    """
    try:
        if not file_path:
            print(warning("Auto image augmentation..."))
            for category in ("Apple", "Grape"):
                augment(category)
        else:
            img = read_image(file_path)
            every_transform = (
                flip,
                rotate,
                perspective,
                brightness,
                contrast,
                saturation,
            )
            for transform in every_transform:
                transform(img, file_path)
    except Exception as e:
        print(e)
if __name__ == "__main__":
    # Adjacent string literals are joined at compile time. A backslash
    # continuation inside the literal would instead embed the next line's
    # leading whitespace (or no separator at all) in the description.
    parser = argparse.ArgumentParser(
        description=(
            "A program to augment images samples "
            "by applying 6 types of transformation."
        )
    )
    # Optional positional argument: when it is omitted, file_path is None.
    parser.add_argument(
        'file_path',
        type=str,
        nargs='?',
        help='Image file path to transform to 6 different types.',
    )
    args = parser.parse_args()
    main(args.file_path)