-
Notifications
You must be signed in to change notification settings - Fork 0
/
01_training.py
40 lines (31 loc) · 988 Bytes
/
01_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import logging
import dvc.api
import pandas as pd
import xgboost as xgb
from dvclive import Live
def main() -> None:
    """Train an XGBoost classifier and log the saved model with DVCLive.

    Steps, in order:

    1. Configure root logging at INFO level.
    2. Read the train/validation feature and target parquet files from
       ``data/dataset``.
    3. Fetch parameters via ``dvc.api.params_show()`` and forward every
       returned key as a keyword argument to ``xgb.XGBClassifier``.
    4. Fit on the training split, passing the validation split as
       ``eval_set``.
    5. Save the model to ``model.json`` and log that file as a DVCLive
       artifact named ``mymodel``.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    logging.info("Start")

    # Files are read lazily by the generator, in the listed order:
    # train features, train target, validation features, validation target.
    train_features, train_target, val_features, val_target = (
        pd.read_parquet(parquet_path)
        for parquet_path in (
            "data/dataset/X_train.parquet",
            "data/dataset/y_train.parquet",
            "data/dataset/X_val.parquet",
            "data/dataset/y_val.parquet",
        )
    )

    hyperparams = dvc.api.params_show()

    logging.info("Training")
    classifier = xgb.XGBClassifier(**hyperparams)
    classifier.fit(
        train_features,
        train_target,
        eval_set=[(val_features, val_target)],
        verbose=False,
    )

    # The same path is used for saving and for artifact logging.
    model_file = "model.json"
    classifier.save_model(model_file)

    with Live() as tracker:
        tracker.log_artifact(
            model_file,
            type="model",
            name="mymodel",
            desc="XGBoost",
            labels=["xgboost"],
        )

    logging.info("End")
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()