-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
30 lines (25 loc) · 821 Bytes
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
"""Script entry point: train the model on the cleaned data and evaluate it
on the held-out test split.
"""
from starter.train_model import trainer, get_data, batch_inference
# Names of the columns treated as categorical features. The same list is
# handed to both training and batch evaluation so that both steps operate on
# the same columns (presumably for encoding inside starter.train_model —
# confirm there). Built by whitespace-splitting one literal: none of the
# names contain spaces, so the result is the same ordered list of strings.
CAT_FEATURES = """
    workclass
    education
    marital-status
    occupation
    relationship
    race
    sex
    native-country
""".split()
if __name__ == '__main__':
data_path = 'data/cleaned_data.csv'
model_path = "model/random_forest_model_with_encoder_and_lb.pkl"
print(model_path)
# Get the splitted data
train_data, test_data = get_data(data_path)
# Training the model on the train data
trainer(train_data, model_path, CAT_FEATURES)
# evaluating the model on the test data
precision, recall, f_beta = batch_inference(test_data,
model_path,
CAT_FEATURES)