-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsplit_dataset.py
78 lines (62 loc) · 2.52 KB
/
split_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import random
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
"""
Function used to split the MiniImageNet dataset into training and validation sets, following the 80-20 rule.
"""
def split_dataset(root_dir, train_ratio=0.8):
train_dir = "Dataset/Split/Train"
test_dir = "Dataset/Split/Test"
for class_name in os.listdir(root_dir):
class_path = os.path.join(root_dir, class_name)
if not os.path.isdir(class_path):
continue
file_names = os.listdir(class_path)
random.shuffle(file_names)
split_index = int(len(file_names) * train_ratio)
train_files = file_names[:split_index]
test_files = file_names[split_index:]
train_class_dir = os.path.join(train_dir, class_name)
test_class_dir = os.path.join(test_dir, class_name)
os.makedirs(train_class_dir)
os.makedirs(test_class_dir)
for file_name in train_files:
src = os.path.join(class_path, file_name)
dst = os.path.join(train_class_dir, file_name)
shutil.copy(src, dst)
for file_name in test_files:
src = os.path.join(class_path, file_name)
dst = os.path.join(test_class_dir, file_name)
shutil.copy(src, dst)
print(f"Processed class '{class_name}' with {len(train_files)} train and {len(test_files)} test files.")
"""
Function used to compute the mean and the std of the dataset.
"""
def mean_std_dataset(dataset_path):
pixel_sum = np.zeros(3)
pixel_sq_sum = np.zeros(3)
num_pixels = 0
for class_folder in tqdm(os.listdir(dataset_path), desc="Processing classes"):
class_folder_path = os.path.join(dataset_path, class_folder)
if not os.path.isdir(class_folder_path):
continue
for image_name in os.listdir(class_folder_path):
image_path = os.path.join(class_folder_path, image_name)
with Image.open(image_path) as img:
img = img.convert("RGB")
img_np = np.array(img) / 255.0
pixel_sum += img_np.sum(axis=(0, 1))
pixel_sq_sum += (img_np ** 2).sum(axis=(0, 1))
num_pixels += img_np.shape[0] * img_np.shape[1]
mean = pixel_sum / num_pixels
std = np.sqrt(pixel_sq_sum / num_pixels - mean ** 2)
return mean, std
if __name__ == "__main__":
root_folder = "Dataset/CLEAR"
split_dataset(root_folder)
m, s = mean_std_dataset(root_folder)
print(f"Mean: {m}")
print(f"Standard Deviation: {s}")