-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
88 lines (75 loc) · 2.49 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import os
import mlflow
import yaml
# Names of every pipeline step, listed in the order the pipeline runs them.
# Setting ``main.steps`` to ``"all"`` in the config selects this whole list.
_steps = ["download", "clean", "split", "train", "test"]
# For each step, in pipeline order: the MLflow project directory under
# ``src`` and the keys of ``config[step]`` passed to that project's "main"
# entry point as parameters (in this order).
_STEP_SPECS = {
    "download": ("get_data", ("uri", "file_name", "s3_path")),
    "clean": ("clean_data", ("raw_data", "file_name", "col_names", "s3_path")),
    "split": (
        "split_data",
        ("clean_data", "test_size", "random_seed", "file_names", "s3_path"),
    ),
    "train": (
        "train_model",
        ("train_data", "target", "model_config", "model_name", "s3_path"),
    ),
    "test": ("test_model", ("test_data", "target", "model_path")),
}


def _parse_steps(steps_par):
    """Return the set of step names selected by the ``main.steps`` config value.

    ``"all"`` selects every name in ``_steps``. Otherwise ``steps_par`` is a
    comma-separated string (whitespace around each name is ignored, empty
    entries are dropped) or an iterable of step names.

    Raises:
        ValueError: if a requested name is not a known pipeline step.
    """
    if isinstance(steps_par, str):
        if steps_par.strip() == "all":
            return set(_steps)
        names = steps_par.split(",")
    else:
        # A YAML list (e.g. ``steps: [download, clean]``) arrives as a list.
        names = steps_par
    requested = {str(name).strip() for name in names} - {""}
    unknown = requested.difference(_STEP_SPECS)
    if unknown:
        # Fail loudly: a misspelled step would otherwise silently not run.
        raise ValueError(
            f"Unknown step(s) {sorted(unknown)}; "
            f"valid steps are {list(_STEP_SPECS)}"
        )
    return requested


def go(config: dict):
    """Run the selected pipeline steps as MLflow projects.

    Args:
        config: Parsed configuration. ``config["main"]`` must provide
            ``tracking_uri``, ``experiment_name`` and ``steps`` (``"all"``, a
            comma-separated string, or a list of step names). Each selected
            step ``s`` reads its parameters from ``config[s]`` (see
            ``_STEP_SPECS`` for the keys each step needs).

    Raises:
        ValueError: if ``steps`` names an unknown step.
        KeyError: if a required config key is missing.
    """
    # Validate the step selection before touching MLflow state.
    active_steps = _parse_steps(config["main"]["steps"])

    mlflow.set_tracking_uri(config["main"]["tracking_uri"])
    mlflow.set_experiment(config["main"]["experiment_name"])

    # Steps always run in the fixed order of _STEP_SPECS, regardless of the
    # order in which they are listed in ``steps``.
    for step, (project_dir, param_keys) in _STEP_SPECS.items():
        if step not in active_steps:
            continue
        step_config = config[step]
        mlflow.run(
            os.path.join("src", project_dir),
            "main",
            parameters={key: step_config[key] for key in param_keys},
        )
if __name__ == "__main__":
    # Load the pipeline configuration from ``config.yaml`` in the current
    # working directory. The encoding is explicit so decoding does not depend
    # on the platform's locale default.
    with open("config.yaml", "r", encoding="utf-8") as fh:
        conf = yaml.safe_load(fh)
    go(conf)