-
Notifications
You must be signed in to change notification settings - Fork 0
/
01_create_dataset_split.py
152 lines (125 loc) · 5.13 KB
/
01_create_dataset_split.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
#!/usr/bin/env python
# coding: utf-8
"""
Author : Aditya Jain
Date Started : July 20, 2022
Edited by : Katriona Goldmann, Levan Bokeria
About : Division of dataset into train, validation and test sets
"""
import argparse
import glob
import math
import os
import random

import pandas as pd
def prepare_split_list(global_pd, new_list, fields):
    """Append rows parsed from image paths to a running split DataFrame.

    Each path is expected to end in .../family/genus/species/filename;
    the last four "/"-separated components supply the row values.

    Args:
        global_pd: accumulated DataFrame for one split (train/val/test)
        new_list : list of image paths to convert into rows
        fields   : column names, in the order [filename, family, genus, species]

    Returns:
        A new DataFrame with the parsed rows appended to global_pd.
    """
    rows = [
        [parts[-1], parts[-4], parts[-3], parts[-2]]
        for parts in (path.split("/") for path in new_list)
    ]
    batch = pd.DataFrame(rows, columns=fields, dtype=object)
    return pd.concat([global_pd, batch], ignore_index=True)
def _append_rows(split_frame, image_paths, fields):
    """Append one row per image path to a split DataFrame.

    Each path must end in family/genus/species/filename; the last four
    path components become the row values, in the column order given by
    fields ([filename, family, genus, species]).
    """
    rows = []
    for path in image_paths:
        # normalize the separator so the component split also works on Windows
        parts = path.replace(os.sep, "/").split("/")
        rows.append([parts[-1], parts[-4], parts[-3], parts[-2]])
    batch = pd.DataFrame(rows, columns=fields, dtype=object)
    return pd.concat([split_frame, batch], ignore_index=True)


def create_data_split(args):
    """Main function for creating the dataset split.

    Walks data_dir laid out as family/genus/species/*.jpg, keeps only the
    species listed in the species-list CSV (column `gbif_species_name`),
    shuffles each species' images, and writes
    <filename>-{train,val,test}-split.csv into write_dir.

    Args:
        args: parsed CLI namespace with species_list, data_dir, write_dir,
              train_ratio, val_ratio, test_ratio and filename attributes.

    Raises:
        ValueError: if the three ratios do not sum to 1 (within float
            tolerance).
    """
    data = pd.read_csv(args.species_list, keep_default_na=False)
    wanted_species = set(data["gbif_species_name"])

    data_dir = args.data_dir    # root directory of data
    write_dir = args.write_dir  # split files to be written
    train_spt = args.train_ratio
    val_spt = args.val_ratio
    test_spt = args.test_ratio

    # math.isclose, not ==: exact float comparison rejects valid inputs such
    # as 0.7 + 0.2 + 0.1, whose float sum is 0.9999999999999999.  Raise
    # (not assert) so the check survives `python -O`.
    if not math.isclose(train_spt + val_spt + test_spt, 1.0):
        raise ValueError("Train, val and test ratios should exactly sum to 1")

    fields = ["filename", "family", "genus", "species"]
    train_data = pd.DataFrame(columns=fields, dtype=object)
    val_data = pd.DataFrame(columns=fields, dtype=object)
    test_data = pd.DataFrame(columns=fields, dtype=object)

    # walk family/genus/species; guard clauses keep the nesting shallow
    for family in os.listdir(data_dir):
        family_dir = os.path.join(data_dir, family)
        if not os.path.isdir(family_dir):
            continue
        for genus in os.listdir(family_dir):
            genus_dir = os.path.join(family_dir, genus)
            if not os.path.isdir(genus_dir):
                continue
            for species in os.listdir(genus_dir):
                species_dir = os.path.join(genus_dir, species)
                if species not in wanted_species or not os.path.isdir(species_dir):
                    continue
                image_paths = glob.glob(os.path.join(species_dir, "*.jpg"))
                # NOTE(review): unseeded shuffle — split is not reproducible
                # across runs; seed `random` upstream if that matters.
                random.shuffle(image_paths)
                train_amt = round(len(image_paths) * train_spt)
                val_amt = round(len(image_paths) * val_spt)
                train_data = _append_rows(
                    train_data, image_paths[:train_amt], fields
                )
                val_data = _append_rows(
                    val_data, image_paths[train_amt : train_amt + val_amt], fields
                )
                test_data = _append_rows(
                    test_data, image_paths[train_amt + val_amt :], fields
                )

    # saving the lists to disk (os.path.join: no trailing "/" required on
    # write_dir, unlike the previous string concatenation)
    train_data.to_csv(
        os.path.join(write_dir, args.filename + "-train-split.csv"), index=False
    )
    val_data.to_csv(
        os.path.join(write_dir, args.filename + "-val-split.csv"), index=False
    )
    test_data.to_csv(
        os.path.join(write_dir, args.filename + "-test-split.csv"), index=False
    )

    # printing stats
    print("Training data size: ", len(train_data))
    print("Validation data size: ", len(val_data))
    print("Testing data size: ", len(test_data))
    print("Total images: ", len(train_data) + len(val_data) + len(test_data))
if __name__ == "__main__":
    # Command-line interface; every argument is mandatory.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_dir",
        required=True,
        help="path to the root directory containing the data",
    )
    parser.add_argument(
        "--write_dir",
        required=True,
        help="path to the directory for saving the split files",
    )
    parser.add_argument(
        "--species_list",
        required=True,
        help="path to the species list",
    )
    parser.add_argument(
        "--train_ratio",
        required=True,
        type=float,
        help="proportion of data for training",
    )
    parser.add_argument(
        "--val_ratio",
        required=True,
        type=float,
        help="proportion of data for validation",
    )
    parser.add_argument(
        "--test_ratio",
        required=True,
        type=float,
        help="proportion of data for testing",
    )
    parser.add_argument(
        "--filename",
        required=True,
        help="initial name for the split files",
    )
    create_data_split(parser.parse_args())