-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsubset_training_data.py
68 lines (54 loc) · 2.32 KB
/
subset_training_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import random
import shutil
from concurrent.futures import ThreadPoolExecutor
def copy_file(file, destination):
    """Copy ``file`` to ``destination`` using ``shutil.copy2``.

    ``shutil.copy2`` copies file contents together with metadata. When
    ``destination`` is an existing directory, the copy keeps the source's
    base name inside it. Nothing is returned.
    """
    shutil.copy2(src=file, dst=destination)
    return None
def set_up_target_dirs(base_dir: str) -> "tuple[str, str, str]":
    """
    Create fresh train, val and test directories inside base_dir.

    Each split directory gets an ``images`` and a ``labels`` subdirectory.
    Any existing split directory is deleted first, together with its
    contents, so repeated runs always start from the same empty layout.

    :param base_dir: Directory in which the split directories are created.
        Missing parent directories are created as needed by os.makedirs.
    :return: The train, val and test directory paths. Each is base_dir joined
        with the split name, so they are relative paths only if base_dir is.
    """
    split_dirs = tuple(
        os.path.join(base_dir, split_name) for split_name in ("train", "val", "test")
    )
    for dir_path in split_dirs:
        # Wipe any previous run's output so the layout is reproducible.
        if os.path.exists(dir_path):
            shutil.rmtree(dir_path)
        os.makedirs(os.path.join(dir_path, "images"))
        os.makedirs(os.path.join(dir_path, "labels"))
    train_dir, val_dir, test_dir = split_dirs
    return train_dir, val_dir, test_dir
def copy_files_in_parallel(
    labels_source_dir: str,
    images_source_dir: str,
    train_dir: str,
    val_dir: str,
    test_dir: str,
    num_train_labels: int,
    num_val_labels: int,
    subset_size: int,
    seed=1234,
    image_ext: str = ".jpg",
):
    """
    Randomly sample label files and copy them, with their images, into train/val/test.

    The first num_train_labels sampled labels go to train_dir, the next
    num_val_labels go to val_dir, and the remainder go to test_dir. Each
    label ``<stem>.txt`` is copied to ``<split>/labels``. Its image
    ``<stem><image_ext>`` is copied from images_source_dir to
    ``<split>/images``. Copies run concurrently on a thread pool.

    :param labels_source_dir: Directory holding the ``.txt`` label files.
    :param images_source_dir: Directory holding the matching image files.
    :param train_dir: Train split directory.
    :param val_dir: Validation split directory.
    :param test_dir: Test split directory. For every split, the ``labels``
        and ``images`` subdirectories are created if they are missing.
    :param num_train_labels: Number of sampled labels assigned to train.
    :param num_val_labels: Number of sampled labels assigned to val.
    :param subset_size: Total number of label files to sample.
    :param seed: Seed for the sampling RNG. The same seed and the same
        label directory contents always give the same selection.
    :param image_ext: Extension of the image paired with each label.
    :raises ValueError: From random.sample, if subset_size is negative or
        larger than the number of label files.
    :raises OSError: If any copy fails, e.g. an image file is missing.
    """
    # os.listdir order is filesystem-dependent; sort it so the same seed picks
    # the same files on every machine. Only regular ``.txt`` files are labels.
    label_names = sorted(
        name
        for name in os.listdir(labels_source_dir)
        if name.endswith(".txt")
        and os.path.isfile(os.path.join(labels_source_dir, name))
    )
    all_files = [os.path.join(labels_source_dir, name) for name in label_names]

    # A private RNG produces the same sample as seeding the module-level one,
    # without clobbering global random state that other code may rely on.
    selected_files = random.Random(seed).sample(all_files, subset_size)

    # (labels_dir, images_dir) per split. Ensure they exist: if a destination
    # directory is missing, shutil.copy2 would treat it as a file name and
    # silently overwrite that single file on every copy.
    split_targets = []
    for split_dir in (train_dir, val_dir, test_dir):
        labels_dir = os.path.join(split_dir, "labels")
        images_dir = os.path.join(split_dir, "images")
        os.makedirs(labels_dir, exist_ok=True)
        os.makedirs(images_dir, exist_ok=True)
        split_targets.append((labels_dir, images_dir))
    train_targets, val_targets, test_targets = split_targets
    val_end = num_train_labels + num_val_labels

    futures = []
    with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
        for index, source_label_file_path in enumerate(selected_files):
            if index < num_train_labels:
                target_labels_dir, target_images_dir = train_targets
            elif index < val_end:
                target_labels_dir, target_images_dir = val_targets
            else:
                target_labels_dir, target_images_dir = test_targets

            # Swap only the final extension (not every ".txt" substring).
            stem = os.path.splitext(os.path.basename(source_label_file_path))[0]
            source_image_file_path = os.path.join(images_source_dir, stem + image_ext)

            futures.append(
                executor.submit(copy_file, source_label_file_path, target_labels_dir)
            )
            futures.append(
                executor.submit(copy_file, source_image_file_path, target_images_dir)
            )

    # The executor stores worker exceptions on the futures. Calling result()
    # raises a failed copy (e.g. a missing image) instead of silently dropping it.
    for future in futures:
        future.result()