-
Notifications
You must be signed in to change notification settings - Fork 2
/
create_pipeline.py
47 lines (40 loc) · 1.94 KB
/
create_pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import argparse
import yaml
from utils.pipeline_creation import model_config
import ast
def build_parser():
    """Build the command-line parser for the pipeline-config generator.

    Every option is optional. ``main`` falls back to the YAML config for
    the model name, batch size and output filepath when they are not given.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir', default='./')
    parser.add_argument('-y', '--yaml', default='config/parameters.yaml', help='Config file YAML format')
    parser.add_argument('-m', '--model_name', help='Choose a model name.')
    parser.add_argument('-f', '--fine_tune_checkpoint', help='Define a checkpoint path to fine tuning.')
    # Parse as int so a CLI override is passed on as a number instead of the
    # raw command-line string.
    parser.add_argument('-b', '--batch_size', type=int, help='Batch size number.')
    parser.add_argument('-o', '--output_filepath', help='Output pipeline config filepath.')
    return parser


def count_classes(classes_names):
    """Return the number of entries in *classes_names*.

    Accepts either a string holding a Python literal (for example
    ``"['cat', 'dog']"``), which is parsed with ``ast.literal_eval``, or an
    already-parsed collection such as a native YAML list.
    """
    if isinstance(classes_names, str):
        classes_names = ast.literal_eval(classes_names)
    return len(classes_names)


def main(argv=None):
    """Create a pipeline config file from CLI options and the YAML config.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Raises:
        SystemExit: With status 1 when the YAML config file cannot be read
            or parsed.
        KeyError: When a required key is missing under ``pipeline_config``.
    """
    args = build_parser().parse_args(argv)

    # Top-level boundary: report any read/parse failure and stop with a
    # non-zero status so calling scripts can detect it.
    try:
        with open(args.yaml, 'r') as file:
            config = yaml.safe_load(file)
    except Exception as e:
        print('Error reading the config file {}'.format(args.yaml))
        print(e)
        raise SystemExit(1) from e

    settings = config['pipeline_config']

    # Command-line values win; the config file supplies the fallbacks.
    model_name = args.model_name or settings['model_name']
    fine_tune_checkpoint = args.fine_tune_checkpoint or None
    batch_size = args.batch_size if args.batch_size is not None else settings['batch_size']
    labelmap_path = settings['labelmap_path']
    train_record_path = settings['train_record_path']
    test_record_path = settings['test_record_path']
    num_classes = count_classes(settings['classes_names'])
    output_filepath = args.output_filepath or settings['pipeline_config_filepath']

    pipeline_config = model_config(
        model_name,
        labelmap_path,
        num_classes,
        train_record_path,
        test_record_path,
        batch_size,
        output_filepath,
        fine_tune_checkpoint,
    )
    result_config_filepath = pipeline_config.create_pipeline_config()
    print(f'Pipeline config created at {result_config_filepath}')


if __name__ == "__main__":
    main()