-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #11 from Aura-healthcare/dev
Merge Dev into Main
- Loading branch information
Showing
72 changed files
with
168,630 additions
and
1,391 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -136,4 +136,5 @@ output/db/*csv | |
cloud/ | ||
tests/output/ | ||
exports/ | ||
output/*/* | ||
output/*/* | ||
data/data_pl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import os | ||
import sys | ||
from datetime import datetime as dt | ||
from sklearn.ensemble import RandomForestClassifier | ||
import datetime | ||
import xgboost as xgb | ||
import numpy as np | ||
|
||
# --- Filesystem layout -------------------------------------------------------
# PROJECT_FOLDER is the parent of the directory that contains this file;
# DATA_FOLDER is its "data" sub-directory.
PROJECT_FOLDER = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATA_FOLDER = os.path.join(PROJECT_FOLDER, 'data')

# Hard-coded absolute locations.
# NOTE(review): these look like mount points inside the Airflow / MLflow
# containers -- confirm against the deployment configuration.
ML_DATASET_OUTPUT_FOLDER = "/opt/airflow/output"
AIRFLOW_PREFIX_TO_DATA = '/opt/airflow/data/'
MLRUNS_DIR = '/mlruns'

# CSV files located under AIRFLOW_PREFIX_TO_DATA.
TRAIN_DATA = os.path.join(AIRFLOW_PREFIX_TO_DATA, "train/df_ml_train.csv")
TEST_DATA = os.path.join(AIRFLOW_PREFIX_TO_DATA, "test/df_ml_test.csv")

# CSV files located under ML_DATASET_OUTPUT_FOLDER.
FEATURE_TRAIN_PATH = os.path.join(ML_DATASET_OUTPUT_FOLDER, "ml_train.csv")
FEATURE_TEST_PATH = os.path.join(ML_DATASET_OUTPUT_FOLDER, "ml_test.csv")
|
||
# Column names grouped under COL_TO_DROP.
# NOTE(review): presumably removed from the feature frame before model
# fitting -- the consumer is not visible from this module; confirm.
COL_TO_DROP = ['interval_index', 'interval_start_time', 'set']

# --- Airflow DAG settings ----------------------------------------------------
# START_DATE is a naive datetime (no tzinfo attached).
START_DATE = dt(2021, 8, 1)
CONCURRENCY = 4
SCHEDULE_INTERVAL = datetime.timedelta(hours=2)
DEFAULT_ARGS = {'owner': 'airflow'}

# --- MLflow ------------------------------------------------------------------
TRACKING_URI = 'http://mlflow:5000'
|
||
# Single-model configuration: an XGBoost classifier instance plus the
# hyper-parameter grid associated with it.
# NOTE(review): this duplicates the 'xgboost' entry of MODELS_PARAM; keep the
# two in sync (or confirm that one of them is obsolete).
MODEL_PARAM = {
    'model': xgb.XGBClassifier(),
    'grid_parameters': {
        'nthread': [4],
        'learning_rate': [0.1, 0.01, 0.05],
        # np.arange(3, 5, 2) evaluates to [3]: only one depth is explored.
        'max_depth': np.arange(3, 5, 2),
        'scale_pos_weight': [1],
        # np.arange(15, 25, 2) evaluates to [15, 17, 19, 21, 23].
        'n_estimators': np.arange(15, 25, 2),
        'missing': [-999],
    },
}
|
||
# Per-model configuration, keyed by a short model name. Each entry holds an
# unfitted estimator instance ('model') and the hyper-parameter grid
# associated with it ('grid_parameters').
MODELS_PARAM = {
    'xgboost': {
        'model': xgb.XGBClassifier(),
        'grid_parameters': {
            'nthread': [4],
            'learning_rate': [0.1, 0.01, 0.05],
            # np.arange(3, 5, 2) evaluates to [3]: only one depth is explored.
            'max_depth': np.arange(3, 5, 2),
            'scale_pos_weight': [1],
            # np.arange(15, 25, 2) evaluates to [15, 17, 19, 21, 23].
            'n_estimators': np.arange(15, 25, 2),
            'missing': [-999],
        },
    },
    'random_forest': {
        'model': RandomForestClassifier(),
        'grid_parameters': {
            # np.arange(1, 5, 1) evaluates to [1, 2, 3, 4].
            'min_samples_leaf': np.arange(1, 5, 1),
            # np.arange(1, 7, 1) evaluates to [1, 2, 3, 4, 5, 6].
            'max_depth': np.arange(1, 7, 1),
            # 'sqrt' replaces the former 'auto': for RandomForestClassifier
            # 'auto' meant sqrt(n_features), was deprecated in scikit-learn
            # 1.1 and removed in 1.3, where it raises a parameter error.
            'max_features': ['sqrt'],
            # np.arange(10, 20, 2) evaluates to [10, 12, 14, 16, 18].
            'n_estimators': np.arange(10, 20, 2),
        },
    },
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import os | ||
import sys | ||
from datetime import datetime, timedelta, datetime | ||
|
||
from airflow.decorators import dag, task | ||
from airflow.utils.dates import days_ago | ||
|
||
sys.path.append('.') | ||
from dags.config import (DEFAULT_ARGS, START_DATE, CONCURRENCY, SCHEDULE_INTERVAL) | ||
|
||
|
||
# NOTE(review): the schedule below is hard-coded to two minutes although
# SCHEDULE_INTERVAL is imported from dags.config and never used here --
# confirm which value is intended before switching.
@dag(
    default_args=DEFAULT_ARGS,
    start_date=START_DATE,
    schedule_interval=timedelta(minutes=2),
    concurrency=CONCURRENCY,
)
def predict():
    """Declare the two-step prediction DAG.

    The DAG chains ``prepare_features_with_io_task`` into
    ``predict_with_io_task``: the value returned by the first task is passed
    as the ``feature_path`` argument of the second one.
    """

    @task
    def prepare_features_with_io_task() -> str:
        """Placeholder for the feature-preparation step.

        Not implemented yet: the body is empty, so the task currently yields
        ``None`` despite the ``str`` return annotation.
        """
        pass

    @task
    def predict_with_io_task(feature_path: str) -> None:
        """Placeholder for the prediction step.

        Not implemented yet: ``feature_path`` is accepted but unused.
        """
        pass

    # Wiring the calls is what creates the task dependency.
    feature_path = prepare_features_with_io_task()
    predict_with_io_task(feature_path)


# Instantiate the DAG at module level so it exists as a module attribute.
predict_dag = predict()
Oops, something went wrong.