-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtrain_valid_split.py
91 lines (78 loc) · 3.63 KB
/
train_valid_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import random
from shutil import copyfile
def _list_regular_files(directory):
    """Return the names of the regular files directly inside *directory*.

    Sub-directories are skipped so they are never handed to ``copyfile``.
    """
    return [entry.name for entry in os.scandir(directory) if entry.is_file()]


def _copy_samples(names, src_dir, dst_dir, simulate):
    """Copy every file in *names* from *src_dir* into *dst_dir* (no-op when *simulate*)."""
    if simulate:
        return
    for name in names:
        copyfile(os.path.join(src_dir, name), os.path.join(dst_dir, name))


def split_data(input_dir, output_train, output_valid, training_set_ratio, simulate, combine_families):
    """Randomly split the files of *input_dir* into training and validation folders.

    If *input_dir* contains sub-directories, each one is treated as a class and
    split independently; otherwise the regular files directly inside
    *input_dir* are split as a single group.

    Args:
        input_dir: Source folder. Paths may be absolute or relative to the
            caller's working directory; a trailing separator is optional.
        output_train: Destination folder for training files (created if missing).
        output_valid: Destination folder for validation files (created if missing).
        training_set_ratio: Fraction in [0, 1] of each group kept for training.
            Each group contributes floor((1 - ratio) * n) validation files.
        simulate: If true, create the destination folders but copy no files.
        combine_families: In class mode, copy every class's files directly into
            the output folders instead of per-class sub-folders. Files with the
            same name in different classes then overwrite each other.

    Raises:
        SystemExit: With code 1 if *input_dir* is not a directory.
        ValueError: If *training_set_ratio* is outside [0, 1].

    Selection uses the module-level ``random`` generator, so ``random.seed``
    controls it (given the same directory listing order, which the OS does
    not guarantee).
    """
    validation_set_ratio = 1 - training_set_ratio
    if not os.path.isdir(input_dir):
        print(input_dir, 'Input directory not found. Exiting.')
        # A missing input is an error: exit with a non-zero status.
        raise SystemExit(1)
    if not 0 <= training_set_ratio <= 1:
        # Fail before any output folder is created.
        raise ValueError('training_set_ratio must be within [0, 1], got %r' % (training_set_ratio,))

    # Paths are joined explicitly instead of using os.chdir, so the caller's
    # working directory is untouched and relative paths resolve consistently.
    list_classes = [entry.name for entry in os.scandir(input_dir) if entry.is_dir()]
    if list_classes:
        print('No. classes found: ', len(list_classes))
        # (source folder, destination sub-folder, sample names) per class;
        # '' as sub-folder means "directly into the output folder".
        groups = []
        for cls in list_classes:
            cls_dir = os.path.join(input_dir, cls)
            groups.append((cls_dir, '' if combine_families else cls, _list_regular_files(cls_dir)))
    else:
        list_files = _list_regular_files(input_dir)
        print('No. of files found: ', len(list_files))
        groups = [(input_dir, '', list_files)]

    os.makedirs(output_train, exist_ok=True)
    os.makedirs(output_valid, exist_ok=True)

    for src_dir, sub_dir, samples in groups:
        # The tiny tolerance absorbs float error before flooring: 1 - 0.9 is
        # 0.09999999999999998, and plain int() would turn 10 files into 0.
        n_valid = int(validation_set_ratio * len(samples) + 1e-9)
        valid_idx = set(random.sample(range(len(samples)), n_valid))
        train_samples = [name for k, name in enumerate(samples) if k not in valid_idx]
        validation_samples = [name for k, name in enumerate(samples) if k in valid_idx]

        train_dst = os.path.join(output_train, sub_dir)
        valid_dst = os.path.join(output_valid, sub_dir)
        if sub_dir:
            # Per-class folders are created even in simulate mode.
            os.makedirs(train_dst, exist_ok=True)
            os.makedirs(valid_dst, exist_ok=True)

        _copy_samples(train_samples, src_dir, train_dst, simulate)
        _copy_samples(validation_samples, src_dir, valid_dst, simulate)
    print("Done splitting.")