-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
60 lines (48 loc) · 1.74 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -*- coding:utf-8 -*-
import argparse
import os
import pathlib
import numpy as np
from skimage import io, img_as_float
from config import validation_rate
from trainer import train
def _mnist_label(image_path):
    """Return the name of the directory that contains the image file."""
    return image_path.parent.name


def _chinese_label(image_path):
    """Return the file-name text before the first "." and after the last "_".

    Example: ``img_001_X.png`` -> ``"X"``.
    """
    return image_path.name.split(".")[0].split("_")[-1]


# Maps each supported --input_type value to the function that derives a
# label string from an image path.
_LABEL_EXTRACTORS = {
    "mnist": _mnist_label,
    "chinese": _chinese_label,
}


def _load_images(image_paths):
    """Read each file as a grayscale float image and stack them into one array.

    Each image gets a trailing channel axis via ``[:, :, np.newaxis]``.
    Assumes every file is a readable image of the same height/width, so the
    result has shape (N, H, W, 1). TODO confirm for the datasets in use.
    """
    return np.array(
        [
            img_as_float(io.imread(image_path, as_gray=True))[:, :, np.newaxis]
            for image_path in image_paths
        ]
    )


def _build_label_index(labels):
    """Map each distinct label to an integer index, in sorted label order.

    Sorting makes the mapping reproducible. Iterating a ``set`` of strings
    can yield a different order on every interpreter run because string
    hashing is randomized.
    """
    return {label: idx for idx, label in enumerate(sorted(set(labels)))}


def main(args: argparse.Namespace) -> None:
    """Load labelled images from ``args.input_path`` and train a model on them.

    Args:
        args: Parsed command-line namespace providing ``input_path``,
            ``input_type`` (``"mnist"`` or ``"chinese"``), ``output_path``
            and ``log_path``.

    Raises:
        ValueError: if ``args.input_type`` is not a supported input type.
        FileNotFoundError: if no files are found under ``args.input_path``.
    """
    # Validate the input type before doing any expensive image loading.
    try:
        label_of = _LABEL_EXTRACTORS[args.input_type]
    except KeyError:
        supported = ", ".join(sorted(_LABEL_EXTRACTORS))
        raise ValueError(
            f"Unsupported input_type {args.input_type!r}; "
            f"expected one of: {supported}"
        ) from None

    # Load data: every regular file below the input directory (recursively)
    # is treated as an image.
    input_dir = pathlib.Path(args.input_path)
    image_paths = [
        item.resolve() for item in input_dir.glob("**/*") if item.is_file()
    ]
    if not image_paths:
        raise FileNotFoundError(f"No image files found under {input_dir}")
    images = _load_images(image_paths)

    # Derive a label string per image, then encode the labels as integer indices.
    raw_labels = [label_of(image_path) for image_path in image_paths]
    label_index = _build_label_index(raw_labels)
    encoded_labels = np.array([label_index[label] for label in raw_labels])
    print(label_index)

    train(
        x_train=images,
        y_train=encoded_labels,
        label_index=label_index,
        validation_rate=validation_rate,
        output_dir=args.output_path,
        log_dir=args.log_path,
    )
if __name__ == "__main__":
    # Set up command-line arguments.
    parser = argparse.ArgumentParser(description="aqualium demo")
    parser.add_argument(
        "--input_path",
        default="/kqi/input/images",
        help="directory searched recursively for training images",
    )
    parser.add_argument(
        "--output_path",
        default="/kqi/output/demo",
        help="output directory passed to train() as output_dir (created if missing)",
    )
    parser.add_argument(
        "--log_path",
        default="/kqi/output/logs",
        help="log directory passed to train() as log_dir (created if missing)",
    )
    # Restrict to the layouts main() knows how to label; anything else would
    # otherwise fail only after all images have been loaded.
    parser.add_argument(
        "--input_type",
        default="mnist",
        choices=["mnist", "chinese"],
        help="how labels are derived from image paths",
    )
    args = parser.parse_args()
    os.makedirs(args.output_path, exist_ok=True)
    os.makedirs(args.log_path, exist_ok=True)
    main(args)