diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 4a4d1841a..2cfd3384a 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -120,6 +120,7 @@ nav:
   - Using Postmodeling: postmodeling/index.md
   - Postmodeling & Crosstabs Configuration: postmodeling/postmodeling-config.md
   - Model governance: dirtyduck/ml_governance.md
+  - Predictlist: predictlist/index.md
   - Scaling up: dirtyduck/aws_batch.md
   - Database Provisioner: db.md
   - API Reference:
diff --git a/docs/sources/predictlist/index.md b/docs/sources/predictlist/index.md
new file mode 100644
index 000000000..c5abcacd2
--- /dev/null
+++ b/docs/sources/predictlist/index.md
@@ -0,0 +1,87 @@
+# Retrain and Predict
+Use an existing model group to retrain a new model on all the data up to the current date and then predict forward into the future.
+
+## Examples
+Both examples assume you have already run a Triage Experiment in the past, and know these two pieces of information:
+1. A `model_group_id` from a Triage model group that you want to use to retrain a model and generate predictions
+2. A `prediction_date` to generate your predictions on.
+
+### CLI
+`triage retrainpredict <model_group_id> <prediction_date>`
+
+Example:
+`triage retrainpredict 30 2021-04-04`
+
+The `retrainpredict` command assumes the current path to be the 'project path' in which to train models and write matrices, but this can be overridden with the `--project-path` option.
+
+### Python
+The `Retrainer` class from the `triage.predictlist` module can be used to retrain a model and predict forward.
+
+```python
+from triage.predictlist import Retrainer
+from triage import create_engine
+
+retrainer = Retrainer(
+    db_engine=create_engine(),
+    project_path='/home/you/triage/project2',
+    model_group_id=36,
+)
+retrainer.retrain(prediction_date='2021-04-04')
+retrainer.predict(prediction_date='2021-04-04')
+```
+
+## Output
+The retrained model and its predictions are stored similarly to those created during an Experiment:
+- Raw Matrix saved to the matrices directory in project storage
+- Raw Model saved to the trained_model directory in project storage
+- Retrained Model info saved in a table (triage_metadata.models) where model_comment = 'retrain_2021-04-04 21:19:09.975112'
+- Predictions saved in a table (triage_production.predictions)
+- Prediction metadata (tiebreaking, random seed) saved in a table (triage_production.prediction_metadata)
+
+
+# Predictlist
+If you would like to generate a list of predictions on an already-trained Triage model with new data, you can use the 'Predictlist' module.
+
+# Predict Forward with an Existing Model
+Use an existing model object to generate predictions on new data.
+
+## Examples
+Both examples assume you have already run a Triage Experiment in the past, and know these two pieces of information:
+1. A `model_id` from a Triage model that you want to use to generate predictions
+2. An `as_of_date` to generate your predictions on.
+
+### CLI
+`triage predictlist <model_id> <as_of_date>`
+
+Example:
+`triage predictlist 46 2019-05-06`
+
+The `predictlist` command assumes the current path to be the 'project path' in which to find models and write matrices, but this can be overridden with the `--project-path` option.
+
+### Python
+
+The `predict_forward_with_existed_model` function from the `triage.predictlist` module can be used similarly to the CLI, with the addition of the database engine and project path as inputs.
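+Its full signature is `predict_forward_with_existed_model(db_engine, project_path, model_id, as_of_date)`.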
+```python
+from triage.predictlist import predict_forward_with_existed_model
+from triage import create_engine
+
+predict_forward_with_existed_model(
+    db_engine=create_engine(),
+    project_path='/home/you/triage/project2',
+    model_id=46,
+    as_of_date='2019-05-06'
+)
+```
+
+## Output
+The Predictlist output is stored similarly to the output of an Experiment:
+- Raw Matrix saved to the matrices directory in project storage
+- Predictions saved in a table (triage_production.predictions)
+- Prediction metadata (tiebreaking, random seed) saved in a table (triage_production.prediction_metadata)
+
+## Notes
+- The cohort and features for the Predictlist are all inferred from the Experiment that trained the given model_id (as defined by the experiment_models table).
+- The feature list ensures that imputation flag columns are present for any columns that either needed to be imputed in the training process, or that needed to be imputed in the predictlist dataset.
diff --git a/src/tests/catwalk_tests/test_model_trainers.py b/src/tests/catwalk_tests/test_model_trainers.py
index 8e51c66c9..9b202cce0 100644
--- a/src/tests/catwalk_tests/test_model_trainers.py
+++ b/src/tests/catwalk_tests/test_model_trainers.py
@@ -60,7 +60,6 @@ def set_test_seed():
         misc_db_parameters=dict(),
         matrix_store=get_matrix_store(project_storage),
     )
-
     # assert
     # 1. that the models and feature importances table entries are present
     records = [
@@ -286,11 +285,13 @@ def test_reuse_model_random_seeds(grid_config, default_model_trainer):
     def update_experiment_models(db_engine):
         sql = """
             INSERT INTO triage_metadata.experiment_models(experiment_hash,model_hash)
-            SELECT m.built_by_experiment, m.model_hash
-            FROM triage_metadata.models m
+            SELECT er.run_hash, m.model_hash
+            FROM triage_metadata.models m
+            LEFT JOIN triage_metadata.triage_runs er
+                ON m.built_in_triage_run = er.id
             LEFT JOIN triage_metadata.experiment_models em
-                ON m.model_hash = em.model_hash
-                AND m.built_by_experiment = em.experiment_hash
+                ON m.model_hash = em.model_hash
+                AND er.run_hash = em.experiment_hash
             WHERE em.experiment_hash IS NULL
         """
         db_engine.execute(sql)
diff --git a/src/tests/collate_tests/test_collate.py b/src/tests/collate_tests/test_collate.py
index a4585f20a..622b21582 100755
--- a/src/tests/collate_tests/test_collate.py
+++ b/src/tests/collate_tests/test_collate.py
@@ -4,6 +4,7 @@
 Unit tests for `collate` module.
""" +import pytest from triage.component.collate import Aggregate, Aggregation, Categorical def test_aggregate(): @@ -191,3 +192,54 @@ def test_distinct(): ), ) ) == ["count(distinct (x,y)) FILTER (WHERE date < '2012-01-01')"] + + +def test_Aggregation_colname_aggregate_lookup(): + n = Aggregate("x", "sum", {}) + d = Aggregate("1", "count", {}) + m = Aggregate("y", "avg", {}) + aggregation = Aggregation( + [n, d, m], + groups=['entity_id'], + from_obj="source", + prefix="mysource", + state_table="tbl" + ) + assert aggregation.colname_aggregate_lookup == { + 'mysource_entity_id_x_sum': 'sum', + 'mysource_entity_id_1_count': 'count', + 'mysource_entity_id_y_avg': 'avg' + } + +def test_Aggregation_colname_agg_function(): + n = Aggregate("x", "sum", {}) + d = Aggregate("1", "count", {}) + m = Aggregate("y", "stddev_samp", {}) + aggregation = Aggregation( + [n, d, m], + groups=['entity_id'], + from_obj="source", + prefix="mysource", + state_table="tbl" + ) + + assert aggregation.colname_agg_function('mysource_entity_id_x_sum') == 'sum' + assert aggregation.colname_agg_function('mysource_entity_id_y_stddev_samp') == 'stddev_samp' + + +def test_Aggregation_imputation_flag_base(): + n = Aggregate("x", ["sum", "count"], {}) + m = Aggregate("y", "stddev_samp", {}) + aggregation = Aggregation( + [n, m], + groups=['entity_id'], + from_obj="source", + prefix="mysource", + state_table="tbl" + ) + + assert aggregation.imputation_flag_base('mysource_entity_id_x_sum') == 'mysource_entity_id_x' + assert aggregation.imputation_flag_base('mysource_entity_id_x_count') == 'mysource_entity_id_x' + assert aggregation.imputation_flag_base('mysource_entity_id_y_stddev_samp') == 'mysource_entity_id_y_stddev_samp' + with pytest.raises(KeyError): + aggregation.imputation_flag_base('mysource_entity_id_x_stddev_samp') diff --git a/src/tests/postmodeling_tests/test_model_group_evaluator.py b/src/tests/postmodeling_tests/test_model_group_evaluator.py index 06a31482b..7db44e625 100644 --- a/src/tests/postmodeling_tests/test_model_group_evaluator.py +++ b/src/tests/postmodeling_tests/test_model_group_evaluator.py @@ -11,7 +11,7 @@ def model_group_evaluator(finished_experiment): def test_ModelGroupEvaluator_metadata(model_group_evaluator): assert isinstance(model_group_evaluator.metadata, list) - assert len(model_group_evaluator.metadata) == 8 # 8 model groups expected from basic experiment + assert len(model_group_evaluator.metadata) == 2 # 2 models expected for a model_group from basic experiment for row in model_group_evaluator.metadata: assert isinstance(row, dict) diff --git a/src/tests/results_tests/factories.py b/src/tests/results_tests/factories.py index 7b9c512ae..0b7cf2b2c 100644 --- a/src/tests/results_tests/factories.py +++ b/src/tests/results_tests/factories.py @@ -181,12 +181,12 @@ class Meta: matrix_uuid = factory.SelfAttribute("matrix_rel.matrix_uuid") -class ExperimentRunFactory(factory.alchemy.SQLAlchemyModelFactory): +class TriageRunFactory(factory.alchemy.SQLAlchemyModelFactory): class Meta: - model = schema.ExperimentRun + model = schema.TriageRun sqlalchemy_session = session - experiment_rel = factory.SubFactory(ExperimentFactory) + # experiment_rel = factory.SubFactory(ExperimentFactory) start_time = factory.fuzzy.FuzzyNaiveDateTime(datetime(2008, 1, 1)) start_method = "run" @@ -210,7 +210,7 @@ class Meta: models_skipped = 0 models_errored = 0 last_updated_time = factory.fuzzy.FuzzyNaiveDateTime(datetime(2008, 1, 1)) - current_status = schema.ExperimentRunStatus.started + current_status = 
schema.TriageRunStatus.started stacktrace = "" diff --git a/src/tests/test_cli.py b/src/tests/test_cli.py index 497059381..8df512414 100644 --- a/src/tests/test_cli.py +++ b/src/tests/test_cli.py @@ -2,6 +2,7 @@ import triage.cli as cli from unittest.mock import Mock, patch import os +import datetime # we do not need a real database URL but one SQLalchemy thinks looks like a real one @@ -56,3 +57,22 @@ def test_featuretest(): try_command('featuretest', 'example/config/experiment.yaml', '2017-06-06') featuremock.assert_called_once() cohortmock.assert_called_once() + + +def test_cli_predictlist(): + with patch('triage.cli.predict_forward_with_existed_model', autospec=True) as mock: + try_command('predictlist', '40', '2019-06-04') + mock.assert_called_once() + assert mock.call_args[0][0].url + assert mock.call_args[0][1] + assert mock.call_args[0][2] == 40 + assert mock.call_args[0][3] == datetime.datetime(2019, 6, 4) + + +def test_cli_retrain_predict(): + with patch('triage.cli.Retrainer', autospec=True) as mock: + try_command('retrainpredict', '3', '2021-04-04') + mock.assert_called_once() + assert mock.call_args[0][0].url + assert mock.call_args[0][1] + assert mock.call_args[0][2] == 3 diff --git a/src/tests/test_predictlist.py b/src/tests/test_predictlist.py new file mode 100644 index 000000000..9c95abece --- /dev/null +++ b/src/tests/test_predictlist.py @@ -0,0 +1,110 @@ +from triage.predictlist import Retrainer, predict_forward_with_existed_model, train_matrix_info_from_model_id, experiment_config_from_model_id +from triage.validation_primitives import table_should_have_data + + +def test_predict_forward_with_existed_model_should_write_predictions(finished_experiment): + # given a model id and as-of-date <= today + # and the model id is trained and is linked to an experiment with feature and cohort config + # generate records in triage_production.predictions + # the # of records should equal the size of the cohort for that date + model_id = 1 + as_of_date = '2014-01-01' + predict_forward_with_existed_model( + db_engine=finished_experiment.db_engine, + project_path=finished_experiment.project_storage.project_path, + model_id=model_id, + as_of_date=as_of_date + ) + table_should_have_data( + db_engine=finished_experiment.db_engine, + table_name="triage_production.predictions", + ) + + +def test_predict_forward_with_existed_model_should_be_same_shape_as_cohort(finished_experiment): + model_id = 1 + as_of_date = '2014-01-01' + predict_forward_with_existed_model( + db_engine=finished_experiment.db_engine, + project_path=finished_experiment.project_storage.project_path, + model_id=model_id, + as_of_date=as_of_date) + + num_records_matching_cohort = finished_experiment.db_engine.execute( + f'''select count(*) + from triage_production.predictions + join triage_production.cohort_{finished_experiment.config['cohort_config']['name']} using (entity_id, as_of_date) + ''' + ).first()[0] + + num_records = finished_experiment.db_engine.execute( + 'select count(*) from triage_production.predictions' + ).first()[0] + assert num_records_matching_cohort == num_records + + +def test_predict_forward_with_existed_model_matrix_record_is_populated(finished_experiment): + model_id = 1 + as_of_date = '2014-01-01' + predict_forward_with_existed_model( + db_engine=finished_experiment.db_engine, + project_path=finished_experiment.project_storage.project_path, + model_id=model_id, + as_of_date=as_of_date) + + matrix_records = list(finished_experiment.db_engine.execute( + "select * from triage_metadata.matrices where 
matrix_type = 'production'" + )) + assert len(matrix_records) == 1 + + +def test_experiment_config_from_model_id(finished_experiment): + model_id = 1 + experiment_config = experiment_config_from_model_id(finished_experiment.db_engine, model_id) + assert experiment_config == finished_experiment.config + + +def test_train_matrix_info_from_model_id(finished_experiment): + model_id = 1 + (train_matrix_uuid, matrix_metadata) = train_matrix_info_from_model_id(finished_experiment.db_engine, model_id) + assert train_matrix_uuid + assert matrix_metadata + + +def test_retrain_should_write_model(finished_experiment): + # given a model id and prediction_date + # and the model id is trained and is linked to an experiment with feature and cohort config + # create matrix for retraining a model + # generate records in production models + # retrain_model_hash should be the same with model_hash in triage_metadata.models + model_group_id = 1 + prediction_date = '2014-03-01' + + retrainer = Retrainer( + db_engine=finished_experiment.db_engine, + project_path=finished_experiment.project_storage.project_path, + model_group_id=model_group_id, + ) + retrain_info = retrainer.retrain(prediction_date) + model_comment = retrain_info['retrain_model_comment'] + + records = [ + row + for row in finished_experiment.db_engine.execute( + f"select model_hash from triage_metadata.models where model_comment = '{model_comment}'" + ) + ] + assert len(records) == 1 + assert retrainer.retrain_model_hash == records[0][0] + + retrainer.predict(prediction_date) + + table_should_have_data( + db_engine=finished_experiment.db_engine, + table_name="triage_production.predictions", + ) + + matrix_records = list(finished_experiment.db_engine.execute( + f"select * from triage_metadata.matrices where matrix_uuid = '{retrainer.predict_matrix_uuid}'" + )) + assert len(matrix_records) == 1 diff --git a/src/tests/test_tracking_experiments.py b/src/tests/test_tracking_experiments.py index 0fc2f064a..993560866 100644 --- a/src/tests/test_tracking_experiments.py +++ b/src/tests/test_tracking_experiments.py @@ -1,8 +1,8 @@ from tests.utils import sample_config, populate_source_data from triage.util.db import scoped_session from triage.experiments import MultiCoreExperiment, SingleThreadedExperiment -from triage.component.results_schema import ExperimentRun, ExperimentRunStatus -from tests.results_tests.factories import ExperimentFactory, ExperimentRunFactory, session as factory_session +from triage.component.results_schema import TriageRun, TriageRunStatus +from tests.results_tests.factories import ExperimentFactory, TriageRunFactory, session as factory_session from sqlalchemy.orm import Session import pytest import datetime @@ -30,9 +30,10 @@ def test_experiment_tracker(test_engine, project_path): project_path=project_path, n_processes=4, ) - experiment_run = Session(bind=test_engine).query(ExperimentRun).get(experiment.run_id) - assert experiment_run.current_status == ExperimentRunStatus.started - assert experiment_run.experiment_hash == experiment.experiment_hash + experiment_run = Session(bind=test_engine).query(TriageRun).get(experiment.run_id) + assert experiment_run.current_status == TriageRunStatus.started + assert experiment_run.run_hash == experiment.experiment_hash + assert experiment_run.run_type == 'experiment' assert experiment_run.experiment_class_path == 'triage.experiments.multicore.MultiCoreExperiment' assert experiment_run.platform assert experiment_run.os_user @@ -45,7 +46,7 @@ def test_experiment_tracker(test_engine, 
project_path): assert experiment_run.models_made == 0 experiment.run() - experiment_run = Session(bind=test_engine).query(ExperimentRun).get(experiment.run_id) + experiment_run = Session(bind=test_engine).query(TriageRun).get(experiment.run_id) assert experiment_run.start_method == "run" assert experiment_run.matrices_made == len(experiment.matrix_build_tasks) assert experiment_run.matrices_skipped == 0 @@ -57,7 +58,7 @@ def test_experiment_tracker(test_engine, project_path): assert isinstance(experiment_run.model_building_started, datetime.datetime) assert isinstance(experiment_run.last_updated_time, datetime.datetime) assert not experiment_run.stacktrace - assert experiment_run.current_status == ExperimentRunStatus.completed + assert experiment_run.current_status == TriageRunStatus.completed def test_experiment_tracker_exception(db_engine, project_path): @@ -71,8 +72,8 @@ def test_experiment_tracker_exception(db_engine, project_path): experiment.run() with scoped_session(db_engine) as session: - experiment_run = session.query(ExperimentRun).get(experiment.run_id) - assert experiment_run.current_status == ExperimentRunStatus.failed + experiment_run = session.query(TriageRun).get(experiment.run_id) + assert experiment_run.current_status == TriageRunStatus.failed assert isinstance(experiment_run.last_updated_time, datetime.datetime) assert experiment_run.stacktrace @@ -86,7 +87,7 @@ def test_experiment_tracker_in_parts(test_engine, project_path): experiment.generate_matrices() experiment.train_and_test_models() with scoped_session(test_engine) as session: - experiment_run = session.query(ExperimentRun).get(experiment.run_id) + experiment_run = session.query(TriageRun).get(experiment.run_id) assert experiment_run.start_method == "generate_matrices" @@ -103,8 +104,8 @@ def test_initialize_tracking_and_get_run_id(db_engine_with_results_schema): ) assert run_id with scoped_session(db_engine_with_results_schema) as session: - experiment_run = session.query(ExperimentRun).get(run_id) - assert experiment_run.experiment_hash == experiment_hash + experiment_run = session.query(TriageRun).get(run_id) + assert experiment_run.run_hash == experiment_hash assert experiment_run.experiment_class_path == 'mymodule.MyClassName' assert experiment_run.random_seed == 1234 assert experiment_run.experiment_kwargs == {'key': 'value'} @@ -119,7 +120,7 @@ def test_initialize_tracking_and_get_run_id(db_engine_with_results_schema): def test_get_run_for_update(db_engine_with_results_schema): - experiment_run = ExperimentRunFactory() + experiment_run = TriageRunFactory() factory_session.commit() with get_run_for_update( db_engine=db_engine_with_results_schema, @@ -128,16 +129,16 @@ def test_get_run_for_update(db_engine_with_results_schema): run_obj.stacktrace = "My stacktrace" with scoped_session(db_engine_with_results_schema) as session: - experiment_run_from_db = session.query(ExperimentRun).get(experiment_run.run_id) + experiment_run_from_db = session.query(TriageRun).get(experiment_run.run_id) assert experiment_run_from_db.stacktrace == "My stacktrace" def test_increment_field(db_engine_with_results_schema): - experiment_run = ExperimentRunFactory() + experiment_run = TriageRunFactory() factory_session.commit() increment_field('matrices_made', experiment_run.run_id, db_engine_with_results_schema) increment_field('matrices_made', experiment_run.run_id, db_engine_with_results_schema) with scoped_session(db_engine_with_results_schema) as session: - experiment_run_from_db = 
session.query(ExperimentRun).get(experiment_run.run_id)
+        experiment_run_from_db = session.query(TriageRun).get(experiment_run.run_id)
         assert experiment_run_from_db.matrices_made == 2
diff --git a/src/tests/utils.py b/src/tests/utils.py
index 0fcdb41ef..aade835c5 100644
--- a/src/tests/utils.py
+++ b/src/tests/utils.py
@@ -439,7 +439,7 @@ def sample_config():
         "label_config": label_config,
         "entity_column_name": "entity_id",
         "model_comment": "test2-final-final",
-        "model_group_keys": ["label_name", "label_type", "custom_key"],
+        "model_group_keys": ["label_name", "label_type", "custom_key", "class_path", "parameters"],
         "feature_aggregations": feature_config,
         "cohort_config": cohort_config,
         "temporal_config": temporal_config,
diff --git a/src/triage/cli.py b/src/triage/cli.py
index 0f10e1969..da30a3786 100755
--- a/src/triage/cli.py
+++ b/src/triage/cli.py
@@ -23,6 +23,7 @@
     MultiCoreExperiment,
     SingleThreadedExperiment,
 )
+from triage.predictlist import predict_forward_with_existed_model, Retrainer
 from triage.component.postmodeling.crosstabs import CrosstabsConfigLoader, run_crosstabs
 from triage.component.postmodeling.utils.add_predictions import add_predictions
 from triage.util.db import create_engine
@@ -415,6 +416,68 @@ def __call__(self, args):
             config = CrosstabsConfigLoader(config=yaml.full_load(fd))
         run_crosstabs(db_engine, config)

+@Triage.register
+class RetrainPredict(Command):
+    """Given a model_group_id, retrain a new model on all data up to the current date and predict forward"""
+
+    def __init__(self, parser):
+        parser.add_argument(
+            "model_group_id",
+            type=natural_number,
+            help="The model_group_id to use for retraining and prediction"
+        )
+
+        parser.add_argument(
+            "prediction_date",
+            type=valid_date,
+            help="The date as of which to run features. Format YYYY-MM-DD",
+        )
+        parser.add_argument(
+            "--project-path",
+            default=os.getcwd(),
+            help="path to store matrices and trained models",
+        )
+
+    def __call__(self, args):
+        db_engine = create_engine(self.root.db_url)
+        retrainer = Retrainer(
+            db_engine,
+            args.project_path,
+            args.model_group_id,
+        )
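+        # retrain() fits a new model on all data up to the prediction date;
+        # predict() then scores the cohort as of that same date.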
Format YYYY-MM-DD", + ) + parser.add_argument( + "--project-path", + default=os.getcwd(), + help="path to store matrices and trained models", + ) + + def __call__(self, args): + db_engine = create_engine(self.root.db_url) + predict_forward_with_existed_model( + db_engine, + args.project_path, + args.model_id, + args.as_of_date + ) @Triage.register class Db(Command): diff --git a/src/triage/component/architect/builders.py b/src/triage/component/architect/builders.py index 90889b306..dd77cf4f6 100644 --- a/src/triage/component/architect/builders.py +++ b/src/triage/component/architect/builders.py @@ -1,5 +1,4 @@ import io -import json import verboselogs, logging logger = verboselogs.VerboseLogger(__name__) @@ -32,6 +31,7 @@ def __init__( self.replace = replace self.include_missing_labels_in_train_as = include_missing_labels_in_train_as self.run_id = run_id + self.includes_labels = 'labels_table_name' in self.db_config @property def sessionmaker(self): @@ -131,7 +131,7 @@ def make_entity_date_table( """ as_of_time_strings = [str(as_of_time) for as_of_time in as_of_times] - if matrix_type == "test" or self.include_missing_labels_in_train_as is not None: + if matrix_type == "test" or matrix_type == "production" or self.include_missing_labels_in_train_as is not None: indices_query = self._all_valid_entity_dates_query( as_of_time_strings=as_of_time_strings, state=state ) @@ -232,14 +232,15 @@ def build_matrix( if self.run_id: errored_matrix(self.run_id, self.db_engine) return - if not table_has_data( - f"{self.db_config['labels_schema_name']}.{self.db_config['labels_table_name']}", - self.db_engine, - ): - logger.warning("labels table is not populated, cannot build matrix") - if self.run_id: - errored_matrix(self.run_id, self.db_engine) - return + + if self.includes_labels: + if not table_has_data( + f"{self.db_config['labels_schema_name']}.{self.db_config['labels_table_name']}", + self.db_engine, + ): + logger.warning("labels table is not populated, cannot build matrix") + if self.run_id: + errored_matrix(self.run_id, self.db_engine) matrix_store = self.matrix_storage_engine.get_store(matrix_uuid) if not self.replace and matrix_store.exists: @@ -261,7 +262,7 @@ def build_matrix( matrix_metadata["state"], matrix_type, matrix_uuid, - matrix_metadata["label_timespan"], + matrix_metadata.get("label_timespan", None), ) except ValueError as e: logger.exception( @@ -277,19 +278,26 @@ def build_matrix( as_of_times, feature_dictionary, entity_date_table_name, matrix_uuid ) logger.debug(f"Feature data extracted for matrix {matrix_uuid}") - logger.spam( - "Extracting label data from database into file for matrix {matrix_uuid}", - ) - labels_df = self.load_labels_data( - label_name, - label_type, - entity_date_table_name, - matrix_uuid, - matrix_metadata["label_timespan"], - ) - dataframes.insert(0, labels_df) - logger.debug(f"Label data extracted for matrix {matrix_uuid}") + # dataframes add label_name + + if self.includes_labels: + logger.spam( + "Extracting label data from database into file for matrix {matrix_uuid}", + ) + labels_df = self.load_labels_data( + label_name, + label_type, + entity_date_table_name, + matrix_uuid, + matrix_metadata["label_timespan"], + ) + dataframes.insert(0, labels_df) + logging.debug(f"Label data extracted for matrix {matrix_uuid}") + else: + labels_df = pd.DataFrame(index=dataframes[0].index, columns=[label_name]) + dataframes.insert(0, labels_df) + # stitch together the csvs logger.spam(f"Merging feature files for matrix {matrix_uuid}") output = 
self.merge_feature_csvs(dataframes, matrix_uuid) diff --git a/src/triage/component/architect/feature_generators.py b/src/triage/component/architect/feature_generators.py index 86a0131d4..89f507241 100644 --- a/src/triage/component/architect/feature_generators.py +++ b/src/triage/component/architect/feature_generators.py @@ -606,7 +606,7 @@ def _generate_agg_table_tasks_for(self, aggregation): return table_tasks - def _generate_imp_table_tasks_for(self, aggregation, drop_preagg=True): + def _generate_imp_table_tasks_for(self, aggregation, impute_cols=None, nonimpute_cols=None, drop_preagg=True): """Generate SQL statements for preparing, populating, and finalizing imputations, for each feature group table in the given aggregation. @@ -653,8 +653,10 @@ def _generate_imp_table_tasks_for(self, aggregation, drop_preagg=True): with self.db_engine.begin() as conn: results = conn.execute(aggregation.find_nulls()) null_counts = results.first().items() - impute_cols = [col for (col, val) in null_counts if val > 0] - nonimpute_cols = [col for (col, val) in null_counts if val == 0] + if impute_cols is None: + impute_cols = [col for (col, val) in null_counts if val > 0] + if nonimpute_cols is None: + nonimpute_cols = [col for (col, val) in null_counts if val == 0] # table tasks for imputed aggregation table, most of the work is done here # by collate's get_impute_create() diff --git a/src/triage/component/architect/planner.py b/src/triage/component/architect/planner.py index 82e9220de..9a333a393 100644 --- a/src/triage/component/architect/planner.py +++ b/src/triage/component/architect/planner.py @@ -37,15 +37,17 @@ def _generate_build_task( "matrix_metadata": matrix_metadata, "matrix_type": matrix_metadata["matrix_type"], } - - def _make_metadata( - self, + + @staticmethod + def make_metadata( matrix_definition, feature_dictionary, label_name, label_type, cohort_name, matrix_type, + feature_start_time, + user_metadata, ): """ Generate dictionary of matrix metadata. 
@@ -77,7 +79,7 @@ def _make_metadata(
         )
         matrix_metadata = {
             # temporal information
-            "feature_start_time": self.feature_start_time,
+            "feature_start_time": feature_start_time,
             "end_time": matrix_definition["matrix_info_end_time"],
             "as_of_date_frequency": matrix_definition.get(
                 "training_as_of_date_frequency",
@@ -100,7 +102,7 @@ def _make_metadata(
             "matrix_type": matrix_type,
         }
         matrix_metadata.update(matrix_definition)
-        matrix_metadata.update(self.user_metadata)
+        matrix_metadata.update(user_metadata)

         return matrix_metadata

@@ -138,13 +140,15 @@ def generate_plans(self, matrix_set_definitions, feature_dictionaries):
         ):
             matrix_set_clone = copy.deepcopy(matrix_set)
             # get a uuid
-            train_metadata = self._make_metadata(
+            train_metadata = self.make_metadata(
                 train_matrix,
                 feature_dictionary,
                 label_name,
                 label_type,
                 cohort_name,
                 "train",
+                self.feature_start_time,
+                self.user_metadata,
             )
             train_uuid = filename_friendly_hash(train_metadata)
             logger.debug(
@@ -168,13 +172,15 @@ def generate_plans(self, matrix_set_definitions, feature_dictionaries):

         test_uuids = []
         for test_matrix in matrix_set_clone["test_matrices"]:
-            test_metadata = self._make_metadata(
+            test_metadata = self.make_metadata(
                 test_matrix,
                 feature_dictionary,
                 label_name,
                 label_type,
                 cohort_name,
                 "test",
+                self.feature_start_time,
+                self.user_metadata,
             )
             test_uuid = filename_friendly_hash(test_metadata)
             logger.debug(
diff --git a/src/triage/component/catwalk/model_trainers.py b/src/triage/component/catwalk/model_trainers.py
index 2c7f730b8..9a80ba3b7 100644
--- a/src/triage/component/catwalk/model_trainers.py
+++ b/src/triage/component/catwalk/model_trainers.py
@@ -187,6 +187,7 @@ def _write_model_to_db(
         model_group_id,
         model_size,
         misc_db_parameters,
+        retrain,
     ):
         """Writes model and feature importance data to a database
         Will overwrite the data of any previous versions
@@ -210,22 +211,36 @@ def _write_model_to_db(
             misc_db_parameters (dict) params to pass through to the database
         """
         model_id = retrieve_model_id_from_hash(self.db_engine, model_hash)
-        if model_id and not self.replace:
+        if model_id and not self.replace and not retrain:
             logger.notice(
                 f"Metadata for model {model_id} found in database. Reusing model metadata."
             )
             return model_id
         else:
-            model = Model(
-                model_hash=model_hash,
-                model_type=class_path,
-                hyperparameters=parameters,
-                model_group_id=model_group_id,
-                built_by_experiment=self.experiment_hash,
-                built_in_experiment_run=self.run_id,
-                model_size=model_size,
-                **misc_db_parameters,
-            )
+            if retrain:
+                logger.debug("Retraining model...")
+                model = Model(
+                    model_group_id=model_group_id,
+                    model_hash=model_hash,
+                    model_type=class_path,
+                    hyperparameters=parameters,
+                    # built_by_retrain=self.experiment_hash,
+                    built_in_triage_run=self.run_id,
+                    model_size=model_size,
+                    **misc_db_parameters,
+                )
+
+            else:
+                model = Model(
+                    model_hash=model_hash,
+                    model_type=class_path,
+                    hyperparameters=parameters,
+                    model_group_id=model_group_id,
+                    # built_by_experiment=self.experiment_hash,
+                    built_in_triage_run=self.run_id,
+                    model_size=model_size,
+                    **misc_db_parameters,
+                )
         session = self.sessionmaker()
         if model_id:
             logger.notice(
@@ -240,7 +255,7 @@ def _write_model_to_db(
             model_id = model.model_id
             logger.notice(f"Model {model_id} not found from previous runs. Adding the new model")
         session.close()
-
+
         logger.spam(f"Saving feature importances for model_id {model_id}")
         self._save_feature_importances(
             model_id, get_feature_importances(trained_model), feature_names
@@ -249,7 +264,7 @@ def _write_model_to_db(
         return model_id

     def _train_and_store_model(
-        self, matrix_store, class_path, parameters, model_hash, misc_db_parameters, random_seed
+        self, matrix_store, class_path, parameters, model_hash, misc_db_parameters, random_seed, retrain, model_group_id,
     ):
         """Train a model, cache it, and write metadata to a database

@@ -270,17 +285,27 @@ def _train_and_store_model(

         unique_parameters = self.unique_parameters(parameters)

-        model_group_id = self.model_grouper.get_model_group_id(
-            class_path, unique_parameters, matrix_store.metadata, self.db_engine
-        )
+
+        if retrain:
+            # if retraining, use the provided model_group_id
+            if not model_group_id:
+                raise ValueError("model_group_id must be provided when retrain=True")
+
+        else:
+            model_group_id = self.model_grouper.get_model_group_id(
+                class_path, unique_parameters, matrix_store.metadata, self.db_engine
+            )
+
+        # Writing the model to storage, then getting its size in kilobytes.
+        self.model_storage_engine.write(trained_model, model_hash)
+
         logger.debug(
             f"Trained model: hash {model_hash}, model group {model_group_id} "
         )
-        # Writing th model to storage, then getting its size in kilobytes.
-        self.model_storage_engine.write(trained_model, model_hash)
+
+        logger.spam(f"Cached model: {model_hash}")
+
         model_size = sys.getsizeof(trained_model) / (1024.0)
-        logger.spam(f"Cached model: {model_hash}")
+
         model_id = self._write_model_to_db(
             class_path,
             unique_parameters,
@@ -290,9 +315,10 @@ def _train_and_store_model(
             model_group_id,
             model_size,
             misc_db_parameters,
+            retrain,
         )
         logger.debug(f"Wrote model {model_id} [{model_hash}] to db")
-        return model_id
+        return model_id, model_hash

     @contextmanager
     def cache_models(self):
@@ -350,7 +376,7 @@ def train_models(self, grid_config, misc_db_parameters, matrix_store):
         ]

     def process_train_task(
-        self, matrix_store, class_path, parameters, model_hash, misc_db_parameters, random_seed=None
+        self, matrix_store, class_path, parameters, model_hash, misc_db_parameters, random_seed=None, retrain=False, model_group_id=None,
     ):
         """Trains and stores a model, or skips it and returns the existing id

@@ -387,8 +413,8 @@ def process_train_task(
             f"(reason to train: {reason})"
         )
         try:
-            model_id = self._train_and_store_model(
-                matrix_store, class_path, parameters, model_hash, misc_db_parameters, random_seed
+            model_id, model_hash = self._train_and_store_model(
+                matrix_store, class_path, parameters, model_hash, misc_db_parameters, random_seed, retrain, model_group_id
             )
         except BaselineFeatureNotInMatrix:
             logger.warning(
diff --git a/src/triage/component/catwalk/storage.py b/src/triage/component/catwalk/storage.py
index 1dbca9694..81d60a790 100644
--- a/src/triage/component/catwalk/storage.py
+++ b/src/triage/component/catwalk/storage.py
@@ -23,8 +23,10 @@
     TrainEvaluation,
     TestPrediction,
     TrainPrediction,
+    ListPrediction,
     TestPredictionMetadata,
     TrainPredictionMetadata,
+    ListPredictionMetadata,
     TestAequitas,
     TrainAequitas
 )
@@ -454,7 +456,7 @@ def columns(self, include_label=False):
         if include_label:
             return columns
         else:
-            return [col for col in columns if col != self.metadata["label_name"]]
+            return [col for col in columns if col != self.metadata.get("label_name", None)]

     @property
     def label_column_name(self):
@@ -498,6 +500,8 @@ def matrix_type(self):
             return TrainMatrixType
         elif 
self.metadata["matrix_type"] == "test": return TestMatrixType + elif self.metadata["matrix_type"] == "production": + return ProductionMatrixType else: raise Exception( """matrix metadata for matrix {} must contain 'matrix_type' @@ -544,7 +548,10 @@ def matrix_with_sorted_columns(self, columns): @property def full_matrix_for_saving(self): - return self.design_matrix.assign(**{self.label_column_name: self.labels}) + if self.labels is not None: + return self.design_matrix.assign(**{self.label_column_name: self.labels}) + else: + return self.design_matrix def load_metadata(self): """Load metadata from storage""" @@ -610,3 +617,10 @@ class TrainMatrixType: aequitas_obj = TrainAequitas prediction_metadata_obj = TrainPredictionMetadata is_test = False + + +class ProductionMatrixType(object): + string_name = "production" + prediction_obj = ListPrediction + prediction_metadata_obj = ListPredictionMetadata + diff --git a/src/triage/component/catwalk/utils.py b/src/triage/component/catwalk/utils.py index 38b22a0d6..f83216848 100644 --- a/src/triage/component/catwalk/utils.py +++ b/src/triage/component/catwalk/utils.py @@ -24,7 +24,7 @@ Model, ExperimentMatrix, ExperimentModel, - ExperimentRun, + TriageRun, ) @@ -246,13 +246,13 @@ def retrieve_existing_model_random_seeds(db_engine, model_group_id, train_end_ti from {ExperimentModel.__table__.fullname} experiment_models join {Model.__table__.fullname} models on (experiment_models.model_hash = models.model_hash) - join {ExperimentRun.__table__.fullname} experiment_runs - on (experiment_models.experiment_hash = experiment_runs.experiment_hash) + join {TriageRun.__table__.fullname} triage_runs + on (experiment_models.experiment_hash = triage_runs.run_hash) where models.model_group_id = %s and models.train_end_time = %s and models.train_matrix_uuid = %s and models.training_label_timespan = %s - and experiment_runs.random_seed = %s + and triage_runs.random_seed = %s order by models.run_time DESC, random() """ return [row[0] for row in db_engine.execute(query, model_group_id, train_end_time, train_matrix_uuid, training_label_timespan, experiment_random_seed)] @@ -269,7 +269,7 @@ def retrieve_experiment_seed_from_run_id(db_engine, run_id): """ session = sessionmaker(bind=db_engine)() try: - return session.query(ExperimentRun).get(run_id).random_seed + return session.query(TriageRun).get(run_id).random_seed finally: session.close() diff --git a/src/triage/component/collate/collate.py b/src/triage/component/collate/collate.py index 804642bc6..457d8d1a2 100644 --- a/src/triage/component/collate/collate.py +++ b/src/triage/component/collate/collate.py @@ -29,6 +29,10 @@ } +class NoAggregateFunctionError(ValueError): + pass + + def make_list(a): return [a] if not isinstance(a, list) else a @@ -497,6 +501,24 @@ def colname_aggregate_lookup(self): lookup[col.name] = agg return lookup + def colname_agg_function(self, colname): + if colname.endswith('_imp'): + raise ValueError('Imputation flag columns cannot have their aggregation function inferred') + + aggregate = self.colname_aggregate_lookup[colname] + if hasattr(aggregate, 'functions'): + used_function = next(funcname for funcname in aggregate.functions if colname.endswith(funcname)) + return used_function + else: + raise NoAggregateFunctionError() + + def imputation_flag_base(self, colname): + used_function = self.colname_agg_function(colname) + if used_function in AGGFUNCS_NEED_MULTIPLE_VALUES: + return colname + else: + return colname.rstrip('_' + used_function) + def _col_prefix(self, group): """ Helper for 
creating a column prefix for the group @@ -726,18 +748,14 @@ def _get_impute_select(self, impute_cols, nonimpute_cols, partitionby=None): # the function, and see its available functions. we expect exactly one of # these functions to end the column name and remove it if so # this is passed to the imputer - if hasattr(self.colname_aggregate_lookup[col], 'functions'): - agg_functions = self.colname_aggregate_lookup[col].functions - used_function = next(funcname for funcname in agg_functions if col.endswith(funcname)) - if used_function in AGGFUNCS_NEED_MULTIPLE_VALUES: - impflag_basecol = col - else: - impflag_basecol = col.rstrip('_' + used_function) - else: + try: + impflag_basecol = self.imputation_flag_base(col) + except NoAggregateFunctionError: logger.warning("Imputation flag merging is not implemented for " "AggregateExpression objects that don't define an aggregate " "function (e.g. composites)") impflag_basecol = col + impute_rule = imprules[col] try: diff --git a/src/triage/component/results_schema/__init__.py b/src/triage/component/results_schema/__init__.py index 40bf26007..3d02cb695 100644 --- a/src/triage/component/results_schema/__init__.py +++ b/src/triage/component/results_schema/__init__.py @@ -13,14 +13,16 @@ from .schema import ( Base, Experiment, + Retrain, FeatureImportance, IndividualImportance, ListPrediction, ExperimentMatrix, Matrix, ExperimentModel, - ExperimentRun, - ExperimentRunStatus, + RetrainModel, + TriageRun, + TriageRunStatus, Model, ModelGroup, Subset, @@ -30,6 +32,7 @@ TrainPrediction, TestPredictionMetadata, TrainPredictionMetadata, + ListPredictionMetadata, TrainAequitas, TestAequitas ) @@ -38,14 +41,16 @@ __all__ = ( "Base", "Experiment", + "Retrain", "FeatureImportance", "IndividualImportance", "ListPrediction", "ExperimentMatrix", "Matrix", + "RetrainModel", "ExperimentModel", - "ExperimentRun", - "ExperimentRunStatus", + "TriageRun", + "TriageRunStatus", "Model", "ModelGroup", "Subset", @@ -55,6 +60,7 @@ "TrainPrediction", "TestPredictionMetadata", "TrainPredictionMetadata", + "ListPredictionMetadata", "TestAequitas", "TrainAequitas", "mark_db_as_upgraded", diff --git a/src/triage/component/results_schema/alembic/versions/079a74c15e8b_merge_b097e47ba829_with_cdd0dc9d9870.py b/src/triage/component/results_schema/alembic/versions/079a74c15e8b_merge_b097e47ba829_with_cdd0dc9d9870.py new file mode 100644 index 000000000..73021015a --- /dev/null +++ b/src/triage/component/results_schema/alembic/versions/079a74c15e8b_merge_b097e47ba829_with_cdd0dc9d9870.py @@ -0,0 +1,24 @@ +"""merge b097e47ba829 with cdd0dc9d9870 + +Revision ID: 079a74c15e8b +Revises: b097e47ba829, cdd0dc9d9870 +Create Date: 2021-05-30 20:49:19.039280 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '079a74c15e8b' +down_revision = ('b097e47ba829', 'cdd0dc9d9870') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/src/triage/component/results_schema/alembic/versions/1b990cbc04e4_production_schema.py b/src/triage/component/results_schema/alembic/versions/1b990cbc04e4_production_schema.py new file mode 100644 index 000000000..b6e1b060e --- /dev/null +++ b/src/triage/component/results_schema/alembic/versions/1b990cbc04e4_production_schema.py @@ -0,0 +1,26 @@ +"""empty message + +Revision ID: 1b990cbc04e4 +Revises: 0bca1ba9706e +Create Date: 2019-02-20 16:41:22.810452 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. 
+revision = '1b990cbc04e4'
+down_revision = '45219f25072b'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.execute("CREATE SCHEMA IF NOT EXISTS production")
+    op.execute("ALTER TABLE triage_metadata.list_predictions SET SCHEMA production;")
+
+
+def downgrade():
+    op.execute("ALTER TABLE production.list_predictions SET SCHEMA triage_metadata;")
+    op.execute("DROP SCHEMA IF EXISTS production")
diff --git a/src/triage/component/results_schema/alembic/versions/264786a9fe85_add_label_value_to_prodcution_table.py b/src/triage/component/results_schema/alembic/versions/264786a9fe85_add_label_value_to_prodcution_table.py
new file mode 100644
index 000000000..fbeb48e6d
--- /dev/null
+++ b/src/triage/component/results_schema/alembic/versions/264786a9fe85_add_label_value_to_prodcution_table.py
@@ -0,0 +1,54 @@
+"""add label_value to production table
+
+Revision ID: 264786a9fe85
+Revises: 1b990cbc04e4
+Create Date: 2019-02-26 13:17:05.365654
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '264786a9fe85'
+down_revision = '1b990cbc04e4'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.drop_table("list_predictions", schema="production")
+    op.create_table(
+        "list_predictions",
+        sa.Column("model_id", sa.Integer(), nullable=False),
+        sa.Column("entity_id", sa.BigInteger(), nullable=False),
+        sa.Column("as_of_date", sa.DateTime(), nullable=False),
+        sa.Column("score", sa.Numeric(), nullable=True),
+        sa.Column('label_value', sa.Integer, nullable=True),
+        sa.Column("rank_abs", sa.Integer(), nullable=True),
+        sa.Column("rank_pct", sa.Float(), nullable=True),
+        sa.Column("matrix_uuid", sa.Text(), nullable=True),
+        sa.Column("test_label_timespan", sa.Interval(), nullable=True),
+        sa.ForeignKeyConstraint(["model_id"], ["triage_metadata.models.model_id"]),
+        sa.PrimaryKeyConstraint("model_id", "entity_id", "as_of_date"),
+        schema="production",
+    )
+
+
+def downgrade():
+    op.drop_table("list_predictions", schema="production")
+    op.create_table(
+        "list_predictions",
+        sa.Column("model_id", sa.Integer(), nullable=False),
+        sa.Column("entity_id", sa.BigInteger(), nullable=False),
+        sa.Column("as_of_date", sa.DateTime(), nullable=False),
+        sa.Column("score", sa.Numeric(), nullable=True),
+        sa.Column("rank_abs", sa.Integer(), nullable=True),
+        sa.Column("rank_pct", sa.Float(), nullable=True),
+        sa.Column("matrix_uuid", sa.Text(), nullable=True),
+        sa.Column("test_label_timespan", sa.Interval(), nullable=True),
+        sa.ForeignKeyConstraint(["model_id"], ["triage_metadata.models.model_id"]),
+        sa.PrimaryKeyConstraint("model_id", "entity_id", "as_of_date"),
+        schema="results",
+    )
+
diff --git a/src/triage/component/results_schema/alembic/versions/5dd2ba8222b1_add_run_type.py b/src/triage/component/results_schema/alembic/versions/5dd2ba8222b1_add_run_type.py
new file mode 100644
index 000000000..81b36615a
--- /dev/null
+++ b/src/triage/component/results_schema/alembic/versions/5dd2ba8222b1_add_run_type.py
@@ -0,0 +1,58 @@
+"""add run_type
+
+Revision ID: 5dd2ba8222b1
+Revises: 079a74c15e8b
+Create Date: 2021-07-22 23:53:04.043651
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '5dd2ba8222b1' +down_revision = '079a74c15e8b' +branch_labels = None +depends_on = None + + +def upgrade(): + op.add_column('experiment_runs', sa.Column('run_type', sa.Text(), nullable=True), schema='triage_metadata') + op.execute("UPDATE triage_metadata.experiment_runs SET run_type='experiment' WHERE run_type IS NULL") + + op.alter_column('experiment_runs', 'experiment_hash', nullable=True, new_column_name='run_hash', schema='triage_metadata') + op.drop_constraint('experiment_runs_experiment_hash_fkey', 'experiment_runs', type_='foreignkey', schema='triage_metadata') + + op.execute("ALTER TABLE triage_metadata.experiment_runs RENAME TO triage_runs") + + op.create_table('retrain', + sa.Column('retrain_hash', sa.Text(), nullable=False), + sa.Column('config', postgresql.JSONB(astext_type=sa.Text()), nullable=True), + sa.Column('prediction_date', sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint('retrain_hash'), + schema='triage_metadata', + ) + + op.alter_column('models', 'built_in_experiment_run', nullable=False, new_column_name='built_in_triage_run', schema='triage_metadata') + op.execute("CREATE TABLE triage_metadata.deprecated_models_built_by_experiment AS SELECT model_id, model_hash, built_by_experiment FROM triage_metadata.models") + op.drop_column('models', 'built_by_experiment', schema='triage_metadata') + + op.create_table('retrain_models', + sa.Column('retrain_hash', sa.String(), nullable=False), + sa.Column('model_hash', sa.String(), nullable=False), + sa.ForeignKeyConstraint(['retrain_hash'], ['triage_metadata.retrain.retrain_hash'], ), + sa.PrimaryKeyConstraint('retrain_hash', 'model_hash'), + schema='triage_metadata' + ) + + +def downgrade(): + op.execute("ALTER TABLE triage_metadata.triage_runs RENAME TO experiment_runs") + op.drop_column('experiment_runs', 'run_type', schema='triage_metadata') + op.alter_column('experiment_runs', 'run_hash', nullable=True, new_column_name='experiment_hash', schema='triage_metadata') + op.create_foreign_key('experiment_runs_experiment_hash_fkey', 'experiment_runs', 'experiments', ['experiment_hash'], ['experiment_hash'], source_schema='triage_metadata', referent_schema='triage_metadata') + op.drop_table('retrain_models', schema='triage_metadata') + op.drop_table('retrain', schema='triage_metadata') + op.add_column('models', sa.Column('built_by_experiment', sa.Text(), nullable=True), schema='triage_metadata') + op.alter_column('models', 'built_in_triage_run', nullable=False, new_column_name='built_in_experiment_run', schema='triage_metadata') + diff --git a/src/triage/component/results_schema/alembic/versions/670289044eb2_add_production_prediction_metadata.py b/src/triage/component/results_schema/alembic/versions/670289044eb2_add_production_prediction_metadata.py new file mode 100644 index 000000000..7146142c2 --- /dev/null +++ b/src/triage/component/results_schema/alembic/versions/670289044eb2_add_production_prediction_metadata.py @@ -0,0 +1,38 @@ +"""Add production prediction metadata + +Revision ID: 670289044eb2 +Revises: ce5b50ffa8e2 +Create Date: 2021-01-08 22:27:23.433813 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = '670289044eb2' +down_revision = 'ce5b50ffa8e2' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
###
+    op.create_table('prediction_metadata',
+        sa.Column('model_id', sa.Integer(), nullable=False),
+        sa.Column('matrix_uuid', sa.Text(), nullable=False),
+        sa.Column('tiebreaker_ordering', sa.Text(), nullable=True),
+        sa.Column('random_seed', sa.Integer(), nullable=True),
+        sa.Column('predictions_saved', sa.Boolean(), nullable=True),
+        sa.ForeignKeyConstraint(['matrix_uuid'], ['triage_metadata.matrices.matrix_uuid'], ),
+        sa.ForeignKeyConstraint(['model_id'], ['triage_metadata.models.model_id'], ),
+        sa.PrimaryKeyConstraint('model_id', 'matrix_uuid'),
+        schema='production'
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_table('prediction_metadata', schema='production')
+    # ### end Alembic commands ###
diff --git a/src/triage/component/results_schema/alembic/versions/cdd0dc9d9870_rename_production_schema_and_prediction_table.py b/src/triage/component/results_schema/alembic/versions/cdd0dc9d9870_rename_production_schema_and_prediction_table.py
new file mode 100644
index 000000000..173e3e117
--- /dev/null
+++ b/src/triage/component/results_schema/alembic/versions/cdd0dc9d9870_rename_production_schema_and_prediction_table.py
@@ -0,0 +1,30 @@
+"""rename production schema to triage_production and list_predictions to predictions
+
+Revision ID: cdd0dc9d9870
+Revises: 670289044eb2
+Create Date: 2021-04-13 00:53:56.098572
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'cdd0dc9d9870'
+down_revision = '670289044eb2'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    op.execute("CREATE SCHEMA IF NOT EXISTS triage_production")
+    op.execute("ALTER TABLE production.list_predictions SET SCHEMA triage_production;")
+    op.execute("ALTER TABLE production.prediction_metadata SET SCHEMA triage_production")
+    op.execute("ALTER TABLE triage_production.list_predictions RENAME TO predictions")
+
+
+def downgrade():
+    op.execute("ALTER TABLE triage_production.predictions SET SCHEMA production;")
+    op.execute("ALTER TABLE triage_production.prediction_metadata SET SCHEMA production")
+    op.execute("ALTER TABLE production.predictions RENAME TO list_predictions")
+    op.execute("DROP SCHEMA IF EXISTS triage_production")
diff --git a/src/triage/component/results_schema/alembic/versions/ce5b50ffa8e2_break_ties_in_list_predictions.py b/src/triage/component/results_schema/alembic/versions/ce5b50ffa8e2_break_ties_in_list_predictions.py
new file mode 100644
index 000000000..6870ff9b7
--- /dev/null
+++ b/src/triage/component/results_schema/alembic/versions/ce5b50ffa8e2_break_ties_in_list_predictions.py
@@ -0,0 +1,34 @@
+"""Break ties in list predictions
+
+Revision ID: ce5b50ffa8e2
+Revises: 264786a9fe85
+Create Date: 2021-01-08 21:59:13.403934
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'ce5b50ffa8e2'
+down_revision = '264786a9fe85'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! 
### + op.add_column('list_predictions', sa.Column('rank_abs_with_ties', sa.Integer(), nullable=True), schema='production') + op.add_column('list_predictions', sa.Column('rank_pct_with_ties', sa.Float(), nullable=True), schema='production') + op.alter_column('list_predictions', 'rank_abs', new_column_name='rank_abs_no_ties', schema='production') + op.alter_column('list_predictions', 'rank_pct', new_column_name='rank_pct_no_ties', schema='production') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.alter_column('list_predictions', 'rank_abs_no_ties', new_column_name='rank_abs', schema='production') + op.alter_column('list_predictions', 'rank_pct_no_ties', new_column_name='rank_pct', schema='production') + op.drop_column('list_predictions', 'rank_pct_with_ties', schema='production') + op.drop_column('list_predictions', 'rank_abs_with_ties', schema='production') + # ### end Alembic commands ### diff --git a/src/triage/component/results_schema/schema.py b/src/triage/component/results_schema/schema.py index 0f2c53ed6..08694e298 100644 --- a/src/triage/component/results_schema/schema.py +++ b/src/triage/component/results_schema/schema.py @@ -23,6 +23,7 @@ from sqlalchemy.orm import relationship from sqlalchemy.types import ARRAY, Enum from sqlalchemy.sql import func +from sqlalchemy.ext.hybrid import hybrid_property # One declarative_base object for each schema created Base = declarative_base() @@ -31,6 +32,7 @@ "CREATE SCHEMA IF NOT EXISTS triage_metadata;" " CREATE SCHEMA IF NOT EXISTS test_results;" " CREATE SCHEMA IF NOT EXISTS train_results;" + " CREATE SCHEMA IF NOT EXISTS triage_production;" ) event.listen(Base.metadata, "before_create", DDL(schemas)) @@ -58,7 +60,10 @@ class Experiment(Base): __tablename__ = "experiments" __table_args__ = {"schema": "triage_metadata"} - experiment_hash = Column(String, primary_key=True) + experiment_hash = Column( + String, + primary_key=True + ) config = Column(JSONB) time_splits = Column(Integer) as_of_times = Column(Integer) @@ -70,15 +75,27 @@ class Experiment(Base): models_needed = Column(Integer) -class ExperimentRunStatus(enum.Enum): +class Retrain(Base): + __tablename__ = "retrain" + __table_args__ = {"schema": "triage_metadata"} + + retrain_hash = Column( + String, + primary_key=True + ) + config = Column(JSONB) + prediction_date = Column(DateTime) + + +class TriageRunStatus(enum.Enum): started = 1 completed = 2 failed = 3 -class ExperimentRun(Base): +class TriageRun(Base): - __tablename__ = "experiment_runs" + __tablename__ = "triage_runs" __table_args__ = {"schema": "triage_metadata"} run_id = Column("id", Integer, primary_key=True) @@ -87,10 +104,8 @@ class ExperimentRun(Base): git_hash = Column(String) triage_version = Column(String) python_version = Column(String) - experiment_hash = Column( - String, - ForeignKey("triage_metadata.experiments.experiment_hash") - ) + run_type = Column(String) + run_hash = Column(String) platform = Column(Text) os_user = Column(Text) working_directory = Column(Text) @@ -108,11 +123,10 @@ class ExperimentRun(Base): models_skipped = Column(Integer, default=0) models_errored = Column(Integer, default=0) last_updated_time = Column(DateTime, onupdate=datetime.datetime.now) - current_status = Column(Enum(ExperimentRunStatus)) + current_status = Column(Enum(TriageRunStatus)) stacktrace = Column(Text) random_seed = Column(Integer) - experiment_rel = relationship("Experiment") - + class Subset(Base): @@ -138,8 +152,8 @@ class ModelGroup(Base): 
class ListPrediction(Base): - __tablename__ = "list_predictions" - __table_args__ = {"schema": "triage_metadata"} + __tablename__ = "predictions" + __table_args__ = {"schema": "triage_production"} model_id = Column( Integer, ForeignKey("triage_metadata.models.model_id"), primary_key=True @@ -147,14 +161,30 @@ class ListPrediction(Base): entity_id = Column(BigInteger, primary_key=True) as_of_date = Column(DateTime, primary_key=True) score = Column(Numeric) - rank_abs = Column(Integer) - rank_pct = Column(Float) + label_value = Column(Integer) + rank_abs_no_ties = Column(Integer) + rank_abs_with_ties = Column(Integer) + rank_pct_no_ties = Column(Float) + rank_pct_with_ties = Column(Float) matrix_uuid = Column(Text) test_label_timespan = Column(Interval) model_rel = relationship("Model") +class ListPredictionMetadata(Base): + __tablename__ = "prediction_metadata" + __table_args__ = {"schema": "triage_production"} + + model_id = Column( + Integer, ForeignKey("triage_metadata.models.model_id"), primary_key=True + ) + matrix_uuid = Column(Text, ForeignKey("triage_metadata.matrices.matrix_uuid"), primary_key=True) + tiebreaker_ordering = Column(Text) + random_seed = Column(Integer) + predictions_saved = Column(Boolean) + + class ExperimentMatrix(Base): __tablename__ = "experiment_matrices" __table_args__ = {"schema": "triage_metadata"} @@ -205,11 +235,8 @@ class Model(Base): model_comment = Column(Text) batch_comment = Column(Text) config = Column(JSON) - built_by_experiment = Column( - String, ForeignKey("triage_metadata.experiments.experiment_hash") - ) - built_in_experiment_run = Column( - Integer, ForeignKey("triage_metadata.experiment_runs.id") + built_in_triage_run = Column( + Integer, ForeignKey("triage_metadata.triage_runs.id"), nullable=True ) train_end_time = Column(DateTime) test = Column(Boolean) @@ -244,6 +271,20 @@ class ExperimentModel(Base): experiment_rel = relationship("Experiment") +class RetrainModel(Base): + __tablename__ = "retrain_models" + __table_args__ = {"schema": "triage_metadata"} + + retrain_hash = Column( + String, + ForeignKey("triage_metadata.retrain.retrain_hash"), + primary_key=True + ) + model_hash = Column(String, primary_key=True) + model_rel = relationship("Model", primaryjoin=(Model.model_hash == model_hash), foreign_keys=model_hash) + retrain_rel = relationship("Retrain") + + class FeatureImportance(Base): __tablename__ = "feature_importances" diff --git a/src/triage/experiments/validate.py b/src/triage/experiments/validate.py index 736feaafe..1e583e249 100644 --- a/src/triage/experiments/validate.py +++ b/src/triage/experiments/validate.py @@ -695,7 +695,7 @@ def _run(self, model_group_keys, user_metadata): ) ) classifier_keys = ["class_path", "parameters"] - # planner_keys are defined in architect.Planner._make_metadata + # planner_keys are defined in architect.Planner.make_metadata planner_keys = [ "feature_start_time", "end_time", diff --git a/src/triage/predictlist/__init__.py b/src/triage/predictlist/__init__.py new file mode 100644 index 000000000..29050ec9b --- /dev/null +++ b/src/triage/predictlist/__init__.py @@ -0,0 +1,610 @@ +from triage.component.results_schema import upgrade_db, Retrain, TriageRun, TriageRunStatus +from triage.component.architect.entity_date_table_generators import EntityDateTableGenerator, DEFAULT_ACTIVE_STATE +from triage.component.architect.features import ( + FeatureGenerator, + FeatureDictionaryCreator, + FeatureGroupCreator, + FeatureGroupMixer, +) +from triage.component.architect.feature_group_creator import 
FeatureGroup
+from triage.component.architect.builders import MatrixBuilder
+from triage.component.architect.planner import Planner
+from triage.component.architect.label_generators import LabelGenerator
+from triage.component.timechop import Timechop
+from triage.component.catwalk.storage import ModelStorageEngine, ProjectStorage
+from triage.component.catwalk import ModelTrainer
+from triage.component.catwalk.model_trainers import flatten_grid_config
+from triage.component.catwalk.predictors import Predictor
+from triage.component.catwalk.utils import retrieve_model_hash_from_id, filename_friendly_hash, retrieve_experiment_seed_from_run_id
+from triage.util.conf import convert_str_to_relativedelta, dt_from_str
+from triage.util.db import scoped_session, get_for_update
+from triage.util.introspection import classpath
+from triage.tracking import (
+    infer_git_hash,
+    infer_ec2_instance_type,
+    infer_installed_libraries,
+    infer_python_version,
+    infer_triage_version,
+    infer_log_location,
+)
+from .utils import (
+    experiment_config_from_model_id,
+    experiment_config_from_model_group_id,
+    get_model_group_info,
+    train_matrix_info_from_model_id,
+    get_feature_names,
+    get_feature_needs_imputation_in_train,
+    get_feature_needs_imputation_in_production,
+    associate_models_with_retrain,
+    save_retrain_and_get_hash,
+    get_retrain_config_from_model_id,
+    temporal_params_from_matrix_metadata,
+)
+
+
+from collections import OrderedDict
+import json
+import random
+import platform
+import getpass
+import os
+from datetime import datetime
+
+import verboselogs, logging
+logger = verboselogs.VerboseLogger(__name__)
+
+
+def predict_forward_with_existed_model(db_engine, project_path, model_id, as_of_date):
+    """Predict forward given model_id and as_of_date and store the predictions in the database
+
+    Args:
+        db_engine (sqlalchemy.db.engine)
+        project_path (string) path to the project directory used for matrix and model storage
+        model_id (int) The id of a given model in the database
+        as_of_date (string) a date string like "YYYY-MM-DD"
+    """
+    logger.spam("In PREDICT LIST................")
+    upgrade_db(db_engine=db_engine)
+    project_storage = ProjectStorage(project_path)
+    matrix_storage_engine = project_storage.matrix_storage_engine()
+    # 1. Get feature and cohort config from database
+    (train_matrix_uuid, matrix_metadata) = train_matrix_info_from_model_id(db_engine, model_id)
+    experiment_config = experiment_config_from_model_id(db_engine, model_id)
+
+    # 2. Generate cohort
+    cohort_table_name = f"triage_production.cohort_{experiment_config['cohort_config']['name']}"
+    cohort_table_generator = EntityDateTableGenerator(
+        db_engine=db_engine,
+        query=experiment_config['cohort_config']['query'],
+        entity_date_table_name=cohort_table_name
+    )
+    cohort_table_generator.generate_entity_date_table(as_of_dates=[dt_from_str(as_of_date)])
+
+    # 3. Generate feature aggregations
+    feature_generator = FeatureGenerator(
+        db_engine=db_engine,
+        features_schema_name="triage_production",
+        feature_start_time=experiment_config['temporal_config']['feature_start_time'],
+    )
+    collate_aggregations = feature_generator.aggregations(
+        feature_aggregation_config=experiment_config['feature_aggregations'],
+        feature_dates=[as_of_date],
+        state_table=cohort_table_name
+    )
+    feature_generator.process_table_tasks(
+        feature_generator.generate_all_table_tasks(
+            collate_aggregations,
+            task_type='aggregation'
+        )
+    )
+
+    # 4. Reconstruct feature dictionary from feature_names and generate imputation
+
+    reconstructed_feature_dict = FeatureGroup()
+    imputation_table_tasks = OrderedDict()
+
+    for aggregation in collate_aggregations:
+        feature_group, feature_names = get_feature_names(aggregation, matrix_metadata)
+        reconstructed_feature_dict[feature_group] = feature_names
+
+        # Make sure that features imputed in training are also imputed in production
+
+        features_imputed_in_train = get_feature_needs_imputation_in_train(aggregation, feature_names)
+
+        features_imputed_in_production = get_feature_needs_imputation_in_production(aggregation, db_engine)
+
+        total_impute_cols = set(features_imputed_in_production) | set(features_imputed_in_train)
+        total_nonimpute_cols = set(f for f in set(feature_names) if '_imp' not in f) - total_impute_cols
+
+        task_generator = feature_generator._generate_imp_table_tasks_for
+
+        imputation_table_tasks.update(task_generator(
+            aggregation,
+            impute_cols=list(total_impute_cols),
+            nonimpute_cols=list(total_nonimpute_cols)
+            )
+        )
+    feature_generator.process_table_tasks(imputation_table_tasks)
+
+    # 5. Build matrix
+    db_config = {
+        "features_schema_name": "triage_production",
+        "labels_schema_name": "public",
+        "cohort_table_name": cohort_table_name,
+    }
+
+    matrix_builder = MatrixBuilder(
+        db_config=db_config,
+        matrix_storage_engine=matrix_storage_engine,
+        engine=db_engine,
+        experiment_hash=None,
+        replace=True,
+    )
+
+    feature_start_time = experiment_config['temporal_config']['feature_start_time']
+    label_name = experiment_config['label_config']['name']
+    label_type = 'binary'
+    cohort_name = experiment_config['cohort_config']['name']
+    user_metadata = experiment_config['user_metadata']
+
+    # Use timechop to get the time definition for production
+    temporal_config = experiment_config["temporal_config"]
+    temporal_config.update(temporal_params_from_matrix_metadata(db_engine, model_id))
+    timechopper = Timechop(**temporal_config)
+    prod_definitions = timechopper.define_test_matrices(
+        train_test_split_time=dt_from_str(as_of_date),
+        test_duration=temporal_config['test_durations'][0],
+        test_label_timespan=temporal_config['test_label_timespans'][0]
+    )
+
+    matrix_metadata = Planner.make_metadata(
+        prod_definitions[-1],
+        reconstructed_feature_dict,
+        label_name,
+        label_type,
+        cohort_name,
+        'production',
+        feature_start_time,
+        user_metadata,
+    )
+
+    matrix_metadata['matrix_id'] = str(as_of_date) + f'_model_id_{model_id}' + '_risklist'
+
+    matrix_uuid = filename_friendly_hash(matrix_metadata)
+
+    matrix_builder.build_matrix(
+        as_of_times=[as_of_date],
+        label_name=label_name,
+        label_type=label_type,
+        feature_dictionary=reconstructed_feature_dict,
+        matrix_metadata=matrix_metadata,
+        matrix_uuid=matrix_uuid,
+        matrix_type="production",
+    )
+
+    # 6. Predict the risk score for production
+    predictor = Predictor(
+        model_storage_engine=project_storage.model_storage_engine(),
+        db_engine=db_engine,
+        rank_order='best'
+    )
+
+    predictor.predict(
+        model_id=model_id,
+        matrix_store=matrix_storage_engine.get_store(matrix_uuid),
+        misc_db_parameters={},
+        train_matrix_columns=matrix_storage_engine.get_store(train_matrix_uuid).columns()
+    )
+
+
+class Retrainer:
+    """Given a model_group_id and prediction_date, retrain a model using all the data up to prediction_date
+
+    Args:
+        db_engine (sqlalchemy.engine)
+        project_path (string)
+        model_group_id (int)
+    """
+    def __init__(self, db_engine, project_path, model_group_id):
+        self.retrain_hash = None
+        self.db_engine = db_engine
+        upgrade_db(db_engine=self.db_engine)
+        self.project_storage = ProjectStorage(project_path)
+        self.model_group_id = model_group_id
+        self.model_group_info = get_model_group_info(self.db_engine, self.model_group_id)
+        self.matrix_storage_engine = self.project_storage.matrix_storage_engine()
+        self.triage_run_id, self.experiment_config = experiment_config_from_model_group_id(self.db_engine, self.model_group_id)
+
+        # This feels like it needs some refactoring since in some edge cases at least the test matrix temporal parameters
+        # might differ across models in the model group (the training ones shouldn't), but this should probably work for
+        # the vast majority of use cases...
+        self.experiment_config['temporal_config'].update(temporal_params_from_matrix_metadata(self.db_engine, self.model_group_info['model_id_last_split']))
+
+        # Since "testing" here is predicting forward to a single new date, the test_duration should always be '0day'
+        # (regardless of what it may have been before)
+        self.experiment_config['temporal_config']['test_durations'] = ['0day']
+
+        # These lists should now only contain one item (the value actually used for the last model in this group)
+        self.training_label_timespan = self.experiment_config['temporal_config']['training_label_timespans'][0]
+        self.test_label_timespan = self.experiment_config['temporal_config']['test_label_timespans'][0]
+        self.test_duration = self.experiment_config['temporal_config']['test_durations'][0]
+        self.feature_start_time = self.experiment_config['temporal_config']['feature_start_time']
+
+        self.label_name = self.experiment_config['label_config']['name']
+        self.cohort_name = self.experiment_config['cohort_config']['name']
+        self.user_metadata = self.experiment_config['user_metadata']
+
+        self.feature_dictionary_creator = FeatureDictionaryCreator(
+            features_schema_name='triage_production', db_engine=self.db_engine
+        )
+        self.label_generator = LabelGenerator(
+            label_name=self.experiment_config['label_config'].get("name", None),
+            query=self.experiment_config['label_config']["query"],
+            replace=True,
+            db_engine=self.db_engine,
+        )
+
+        self.labels_table_name = "labels_{}_{}_production".format(
+            self.experiment_config['label_config'].get('name', 'default'),
+            filename_friendly_hash(self.experiment_config['label_config']['query'])
+        )
+
+        self.feature_generator = FeatureGenerator(
+            db_engine=self.db_engine,
+            features_schema_name="triage_production",
+            feature_start_time=self.feature_start_time,
+        )
+
+        self.model_trainer = ModelTrainer(
+            experiment_hash=None,
+            model_storage_engine=ModelStorageEngine(self.project_storage),
+            db_engine=self.db_engine,
+            replace=True,
+            run_id=self.triage_run_id,
+        )
+
+    def get_temporal_config_for_retrain(self, prediction_date):
+        temporal_config = self.experiment_config['temporal_config'].copy()
+        temporal_config['feature_end_time'] = datetime.strftime(prediction_date, "%Y-%m-%d")
+        temporal_config['label_end_time'] = datetime.strftime(
+            prediction_date + convert_str_to_relativedelta(self.test_label_timespan),
+            "%Y-%m-%d")
+        # just needs to be bigger than the gap between the label start and end times
+        # to ensure we only get one time split for the retraining
+        temporal_config['model_update_frequency'] = '%syears' % (
+            dt_from_str(temporal_config['label_end_time']).year -
+            dt_from_str(temporal_config['label_start_time']).year + 10
+        )
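+        # Worked example (hypothetical dates, added for clarity): with
+        # prediction_date 2021-04-04, test_label_timespan '6month', and
+        # label_start_time '2012-01-01', label_end_time becomes '2021-10-04'
+        # and model_update_frequency becomes '19years' (2021 - 2012 + 10),
+        # wider than the whole labeling window, so chop_time() in retrain()
+        # below can only yield a single split.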
+
+        return temporal_config
+
+    def generate_all_labels(self, as_of_date):
+        self.label_generator.generate_all_labels(
+            labels_table=self.labels_table_name,
+            as_of_dates=[as_of_date],
+            label_timespans=[self.training_label_timespan]
+        )
+
+    def generate_entity_date_table(self, as_of_date, entity_date_table_name):
+        cohort_table_generator = EntityDateTableGenerator(
+            db_engine=self.db_engine,
+            query=self.experiment_config['cohort_config']['query'],
+            entity_date_table_name=entity_date_table_name
+        )
+        cohort_table_generator.generate_entity_date_table(as_of_dates=[dt_from_str(as_of_date)])
+
+    def get_collate_aggregations(self, as_of_date, state_table):
+        collate_aggregations = self.feature_generator.aggregations(
+            feature_aggregation_config=self.experiment_config['feature_aggregations'],
+            feature_dates=[as_of_date],
+            state_table=state_table
+        )
+        return collate_aggregations
+
+    def get_feature_dict_and_imputation_task(self, collate_aggregations, model_id):
+        (train_matrix_uuid, matrix_metadata) = train_matrix_info_from_model_id(self.db_engine, model_id)
+        reconstructed_feature_dict = FeatureGroup()
+        imputation_table_tasks = OrderedDict()
+        for aggregation in collate_aggregations:
+            feature_group, feature_names = get_feature_names(aggregation, matrix_metadata)
+            reconstructed_feature_dict[feature_group] = feature_names
+            # Make sure that features imputed in training are also imputed in production
+
+            features_imputed_in_train = get_feature_needs_imputation_in_train(aggregation, feature_names)
+
+            features_imputed_in_production = get_feature_needs_imputation_in_production(aggregation, self.db_engine)
+
+            total_impute_cols = set(features_imputed_in_production) | set(features_imputed_in_train)
+            total_nonimpute_cols = set(f for f in set(feature_names) if '_imp' not in f) - total_impute_cols
+
+            task_generator = self.feature_generator._generate_imp_table_tasks_for
+
+            imputation_table_tasks.update(task_generator(
+                aggregation,
+                impute_cols=list(total_impute_cols),
+                nonimpute_cols=list(total_nonimpute_cols)
+                )
+            )
+        return reconstructed_feature_dict, imputation_table_tasks
+
+    def retrain(self, prediction_date):
+        """Retrain a model by going back one split from prediction_date, so the as_of_date for training would be (prediction_date - training_label_timespan)
+
+        Args:
+            prediction_date(str)
+        """
+        # Retrain config and hash
+        retrain_config = {
+            "model_group_id": self.model_group_id,
+            "prediction_date": prediction_date,
+            "test_label_timespan": self.test_label_timespan,
+            "test_duration": self.test_duration,
+        }
+        self.retrain_hash = save_retrain_and_get_hash(retrain_config, self.db_engine)
+
+        with get_for_update(self.db_engine, Retrain, self.retrain_hash) as retrain:
+            retrain.prediction_date = prediction_date
+
+        # Timechop
+        prediction_date = dt_from_str(prediction_date)
+        temporal_config = self.get_temporal_config_for_retrain(prediction_date)
+        timechopper = Timechop(**temporal_config)
+        chops = timechopper.chop_time()
+        assert len(chops) == 1
+        chops_train_matrix = chops[0]['train_matrix']
+        as_of_date = datetime.strftime(chops_train_matrix['last_as_of_time'], "%Y-%m-%d")
+        retrain_definition = {
+            'first_as_of_time': chops_train_matrix['first_as_of_time'],
+            'last_as_of_time': chops_train_matrix['last_as_of_time'],
+            'matrix_info_end_time': chops_train_matrix['matrix_info_end_time'],
+            'as_of_times': [as_of_date],
+            'training_label_timespan': chops_train_matrix['training_label_timespan'],
+            'max_training_history': chops_train_matrix['max_training_history'],
+            'training_as_of_date_frequency': chops_train_matrix['training_as_of_date_frequency'],
+        }
+
+        # Set up the TriageRun row for this retrain run
+        run = TriageRun(
+            start_time=datetime.now(),
+            git_hash=infer_git_hash(),
+            triage_version=infer_triage_version(),
+            python_version=infer_python_version(),
+            run_type="retrain",
+            run_hash=self.retrain_hash,
+            last_updated_time=datetime.now(),
+            current_status=TriageRunStatus.started,
+            installed_libraries=infer_installed_libraries(),
+            platform=platform.platform(),
+            os_user=getpass.getuser(),
+            working_directory=os.getcwd(),
+            ec2_instance_type=infer_ec2_instance_type(),
+            log_location=infer_log_location(),
+            experiment_class_path=classpath(self.__class__),
+            random_seed=retrieve_experiment_seed_from_run_id(self.db_engine, self.triage_run_id),
+        )
+        run_id = None
+        with scoped_session(self.db_engine) as session:
+            session.add(run)
+            session.commit()
+            run_id = run.run_id
+        if not run_id:
+            raise ValueError("Failed to retrieve run_id from saved row")
+
+        # set ModelTrainer's run_id and experiment_hash for Retrain run
+        self.model_trainer.run_id = run_id
+        self.model_trainer.experiment_hash = self.retrain_hash
+
+        # 1. Generate all labels
+        self.generate_all_labels(as_of_date)
+
+        # 2. Generate cohort
+        cohort_table_name = f"triage_production.cohort_{self.experiment_config['cohort_config']['name']}_retrain"
+        self.generate_entity_date_table(as_of_date, cohort_table_name)
+
+        # 3. Generate feature aggregations
+        collate_aggregations = self.get_collate_aggregations(as_of_date, cohort_table_name)
+        feature_aggregation_table_tasks = self.feature_generator.generate_all_table_tasks(
+            collate_aggregations,
+            task_type='aggregation'
+        )
+        self.feature_generator.process_table_tasks(feature_aggregation_table_tasks)
+
+        # 4. Reconstruct feature dictionary from feature_names and generate imputation
+        reconstructed_feature_dict, imputation_table_tasks = self.get_feature_dict_and_imputation_task(
+            collate_aggregations,
+            self.model_group_info['model_id_last_split'],
+        )
+        feature_group_creator = FeatureGroupCreator(self.experiment_config['feature_group_definition'])
+        feature_group_mixer = FeatureGroupMixer(["all"])
+        feature_group_dict = feature_group_mixer.generate(
+            feature_group_creator.subsets(reconstructed_feature_dict)
+        )[0]
+        self.feature_generator.process_table_tasks(imputation_table_tasks)
+        # 5. Build new matrix
+        db_config = {
+            "features_schema_name": "triage_production",
+            "labels_schema_name": "public",
+            "cohort_table_name": cohort_table_name,
+            "labels_table_name": self.labels_table_name,
+        }
+
+        matrix_builder = MatrixBuilder(
+            db_config=db_config,
+            matrix_storage_engine=self.matrix_storage_engine,
+            engine=self.db_engine,
+            experiment_hash=None,
+            replace=True,
+        )
+        new_matrix_metadata = Planner.make_metadata(
+            matrix_definition=retrain_definition,
+            feature_dictionary=feature_group_dict,
+            label_name=self.label_name,
+            label_type='binary',
+            cohort_name=self.cohort_name,
+            matrix_type='train',
+            feature_start_time=dt_from_str(self.feature_start_time),
+            user_metadata=self.user_metadata,
+        )
+
+        new_matrix_metadata['matrix_id'] = "_".join(
+            [
+                self.label_name,
+                'binary',
+                str(as_of_date),
+                'retrain',
+            ]
+        )
+
+        matrix_uuid = filename_friendly_hash(new_matrix_metadata)
+        matrix_builder.build_matrix(
+            as_of_times=[as_of_date],
+            label_name=self.label_name,
+            label_type='binary',
+            feature_dictionary=feature_group_dict,
+            matrix_metadata=new_matrix_metadata,
+            matrix_uuid=matrix_uuid,
+            matrix_type="train",
+        )
+        retrain_model_comment = 'retrain_' + str(datetime.now())
+
+        misc_db_parameters = {
+            'train_end_time': dt_from_str(as_of_date),
+            'test': False,
+            'train_matrix_uuid': matrix_uuid,
+            'training_label_timespan': self.training_label_timespan,
+            'model_comment': retrain_model_comment,
+        }
+
+        # get the random seed from the last split
+        last_split_train_matrix_uuid, last_split_matrix_metadata = train_matrix_info_from_model_id(
+            self.db_engine,
+            model_id=self.model_group_info['model_id_last_split']
+        )
+
+        random_seed = self.model_trainer.get_or_generate_random_seed(
+            model_group_id=self.model_group_id,
+            matrix_metadata=last_split_matrix_metadata,
+            train_matrix_uuid=last_split_train_matrix_uuid
+        )
+
+        # create retrain model hash
+        retrain_model_hash = self.model_trainer._model_hash(
+            self.matrix_storage_engine.get_store(matrix_uuid).metadata,
+            class_path=self.model_group_info['model_type'],
+            parameters=self.model_group_info['hyperparameters'],
+            random_seed=random_seed,
+        )
+
+        associate_models_with_retrain(self.retrain_hash, (retrain_model_hash, ), self.db_engine)
+
+        retrain_model_id = self.model_trainer.process_train_task(
+            matrix_store=self.matrix_storage_engine.get_store(matrix_uuid),
+            class_path=self.model_group_info['model_type'],
+            parameters=self.model_group_info['hyperparameters'],
+            model_hash=retrain_model_hash,
+            misc_db_parameters=misc_db_parameters,
+            random_seed=random_seed,
+            retrain=True,
+            model_group_id=self.model_group_id
+        )
+
+        self.retrain_model_hash = retrieve_model_hash_from_id(self.db_engine, retrain_model_id)
+        self.retrain_matrix_uuid = matrix_uuid
+        self.retrain_model_id = retrain_model_id
+        return {'retrain_model_comment': retrain_model_comment, 'retrain_model_id': retrain_model_id}
+
+    def predict(self, prediction_date):
+        """Predict forward by creating a matrix with as_of_date = prediction_date and applying the retrained model to it
+
+        Args:
+            prediction_date(str)
+        """
+        cohort_table_name = f"triage_production.cohort_{self.experiment_config['cohort_config']['name']}_predict"
+
+        # 1. Generate cohort
+        self.generate_entity_date_table(prediction_date, cohort_table_name)
+
+        # 2. Generate feature aggregations
+        collate_aggregations = self.get_collate_aggregations(prediction_date, cohort_table_name)
+        self.feature_generator.process_table_tasks(
+            self.feature_generator.generate_all_table_tasks(
+                collate_aggregations,
+                task_type='aggregation'
+            )
+        )
+        # 3. Reconstruct feature dictionary from feature_names and generate imputation
+        reconstructed_feature_dict, imputation_table_tasks = self.get_feature_dict_and_imputation_task(
+            collate_aggregations,
+            self.retrain_model_id
+        )
+        self.feature_generator.process_table_tasks(imputation_table_tasks)
+
+        # 4. Build matrix
+        db_config = {
+            "features_schema_name": "triage_production",
+            "labels_schema_name": "public",
+            "cohort_table_name": cohort_table_name,
+        }
+
+        matrix_builder = MatrixBuilder(
+            db_config=db_config,
+            matrix_storage_engine=self.matrix_storage_engine,
+            engine=self.db_engine,
+            experiment_hash=None,
+            replace=True,
+        )
+        # Use timechop to get the time definition for production
+        temporal_config = self.get_temporal_config_for_retrain(dt_from_str(prediction_date))
+        timechopper = Timechop(**temporal_config)
+
+        retrain_config = get_retrain_config_from_model_id(self.db_engine, self.retrain_model_id)
+
+        prod_definitions = timechopper.define_test_matrices(
+            train_test_split_time=dt_from_str(prediction_date),
+            test_duration=retrain_config['test_duration'],
+            test_label_timespan=retrain_config['test_label_timespan']
+        )
+        last_split_definition = prod_definitions[-1]
+        matrix_metadata = Planner.make_metadata(
+            matrix_definition=last_split_definition,
+            feature_dictionary=reconstructed_feature_dict,
+            label_name=self.label_name,
+            label_type='binary',
+            cohort_name=self.cohort_name,
+            matrix_type='production',
+            feature_start_time=self.feature_start_time,
+            user_metadata=self.user_metadata,
+        )
+
+        matrix_metadata['matrix_id'] = str(prediction_date) + f'_model_id_{self.retrain_model_id}' + '_risklist'
+
+        matrix_uuid = filename_friendly_hash(matrix_metadata)
+
+        matrix_builder.build_matrix(
+            as_of_times=[prediction_date],
+            label_name=self.label_name,
+            label_type='binary',
+            feature_dictionary=reconstructed_feature_dict,
+            matrix_metadata=matrix_metadata,
+            matrix_uuid=matrix_uuid,
+            matrix_type="production",
+        )
+
+        # 5. Predict the risk score for production
+        predictor = Predictor(
+            model_storage_engine=self.project_storage.model_storage_engine(),
+            db_engine=self.db_engine,
+            rank_order='best'
+        )
+
+        predictor.predict(
+            model_id=self.retrain_model_id,
+            matrix_store=self.matrix_storage_engine.get_store(matrix_uuid),
+            misc_db_parameters={},
+            train_matrix_columns=self.matrix_storage_engine.get_store(self.retrain_matrix_uuid).columns(),
+        )
+        self.predict_matrix_uuid = matrix_uuid
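A note on reading the output back out: the predict steps above write scores into `triage_production.predictions` with the ranking columns introduced by the migration at the top of this diff (`rank_abs_no_ties`, `rank_abs_with_ties`, `rank_pct_no_ties`, `rank_pct_with_ties`). A minimal sketch of pulling a ranked list, assuming a reachable database; the connection string, `model_id`, and `as_of_date` below are hypothetical:

```python
from triage import create_engine

# Hypothetical connection string and identifiers, for illustration only
db_engine = create_engine("postgresql://user:pass@localhost/triage_db")

query = """
    SELECT entity_id, score, rank_abs_no_ties, rank_pct_with_ties
    FROM triage_production.predictions
    WHERE model_id = %s AND as_of_date = %s
    ORDER BY rank_abs_no_ties
    LIMIT 10
"""
for row in db_engine.execute(query, 46, '2019-05-06'):
    print(row.entity_id, row.score)
```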
diff --git a/src/triage/predictlist/utils.py b/src/triage/predictlist/utils.py
new file mode 100644
index 000000000..9b5eeaf51
--- /dev/null
+++ b/src/triage/predictlist/utils.py
@@ -0,0 +1,208 @@
+from triage.component.results_schema import RetrainModel, Retrain
+from triage.component.catwalk.utils import db_retry, filename_friendly_hash
+
+import re
+from sqlalchemy.orm import sessionmaker
+import verboselogs, logging
+logger = verboselogs.VerboseLogger(__name__)
+
+
+def experiment_config_from_model_id(db_engine, model_id):
+    """Get original experiment config from model_id
+
+    Args:
+        db_engine (sqlalchemy.db.engine)
+        model_id (int) The id of a given model in the database
+
+    Returns: (dict) experiment config
+    """
+    get_experiment_query = '''select experiments.config
+        from triage_metadata.triage_runs
+        join triage_metadata.models on (triage_runs.id = models.built_in_triage_run)
+        join triage_metadata.experiments
+        on (experiments.experiment_hash = triage_runs.run_hash and triage_runs.run_type='experiment')
+        where model_id = %s
+    '''
+    (config,) = db_engine.execute(get_experiment_query, model_id).first()
+    return config
+
+
+def experiment_config_from_model_group_id(db_engine, model_group_id):
+    """Get original experiment config from model_group_id
+
+    Args:
+        db_engine (sqlalchemy.db.engine)
+        model_group_id (int) The id of a given model group in the database
+
+    Returns: (int, dict) the most recent run id and the experiment config
+    """
+    get_experiment_query = '''
+        select triage_runs.id as run_id, experiments.config
+        from triage_metadata.triage_runs
+        join triage_metadata.models
+        on (triage_runs.id = models.built_in_triage_run)
+        join triage_metadata.experiments
+        on (experiments.experiment_hash = triage_runs.run_hash and triage_runs.run_type='experiment')
+        where model_group_id = %s
+        order by triage_runs.start_time desc
+    '''
+    (run_id, config) = db_engine.execute(get_experiment_query, model_group_id).first()
+    return run_id, config
+
+
+def get_model_group_info(db_engine, model_group_id):
+    query = """
+        SELECT model_group_id, model_type, hyperparameters, model_id as model_id_last_split
+        FROM triage_metadata.models
+        WHERE model_group_id = %s
+        ORDER BY train_end_time DESC
+    """
+    model_group_info = db_engine.execute(query, model_group_id).fetchone()
+    return dict(model_group_info)
+
+
+def train_matrix_info_from_model_id(db_engine, model_id):
+    """Get original train matrix information from model_id
+
+    Args:
+        db_engine (sqlalchemy.db.engine)
+        model_id (int) The id of a given model in the database
+
+    Returns: (str, dict) matrix uuid and matrix metadata
+    """
+    get_train_matrix_query = """
+        select matrix_uuid, matrices.matrix_metadata
+        from triage_metadata.matrices
+        join triage_metadata.models on (models.train_matrix_uuid = matrices.matrix_uuid)
+        where model_id = %s
+    """
+    return db_engine.execute(get_train_matrix_query, model_id).first()
+
+
+def test_matrix_info_from_model_id(db_engine, model_id):
+    """Get original test matrix information from model_id
+
+    Note: because a model may have been tested on multiple matrices, this
+    chooses the matrix associated with the most recently run experiment
+    (then randomly if multiple test matrices are associated with the model_id
+    in that experiment). Generally this will be an edge case, but it may be
+    worth providing more control over which matrix to choose here.
+
+    Args:
+        db_engine (sqlalchemy.db.engine)
+        model_id (int) The id of a given model in the database
+
+    Returns: (str, dict) matrix uuid and matrix metadata
+    """
+    get_test_matrix_query = """
+        select mat.matrix_uuid, mat.matrix_metadata
+        from triage_metadata.matrices mat
+        join test_results.prediction_metadata pm on (pm.matrix_uuid = mat.matrix_uuid)
+        join triage_metadata.triage_runs tr
+        on (mat.built_by_experiment = tr.run_hash AND tr.run_type='experiment')
+        where pm.model_id = %s
+        order by start_time DESC, RANDOM()
+        limit 1
+    """
+    return db_engine.execute(get_test_matrix_query, model_id).first()
+
+
+def temporal_params_from_matrix_metadata(db_engine, model_id):
+    """Read temporal parameters associated with model training/testing from the associated
+    matrices. Because a grid of multiple values may be provided in the experiment config
+    for these parameters, we need to find the specific values that were actually used for
+    the given model at runtime.
+
+    Args:
+        db_engine (sqlalchemy.db.engine)
+        model_id (int) The id of a given model in the database
+
+    Returns: (dict) The parameters for use in a temporal config for timechop
+    """
+    train_uuid, train_metadata = train_matrix_info_from_model_id(db_engine, model_id)
+    test_uuid, test_metadata = test_matrix_info_from_model_id(db_engine, model_id)
+
+    temporal_params = {}
+
+    temporal_params['training_as_of_date_frequencies'] = train_metadata['training_as_of_date_frequency']
+    temporal_params['test_as_of_date_frequencies'] = test_metadata['test_as_of_date_frequency']
+    temporal_params['max_training_histories'] = [train_metadata['max_training_history']]
+    temporal_params['test_durations'] = [test_metadata['test_duration']]
+    temporal_params['training_label_timespans'] = [train_metadata.get('training_label_timespan', train_metadata['label_timespan'])]
+    temporal_params['test_label_timespans'] = [test_metadata.get('test_label_timespan', test_metadata['label_timespan'])]
+
+    return temporal_params
+
+
+def get_feature_names(aggregation, matrix_metadata):
+    """Returns a feature group name and a list of feature names from a SpacetimeAggregation object"""
+    feature_prefix = aggregation.prefix
+    logger.spam("Feature prefix = %s", feature_prefix)
+    feature_group = aggregation.get_table_name(imputed=True).split('.')[1].replace('"', '')
+    logger.spam("Feature group = %s", feature_group)
+    feature_names_in_group = [f for f in matrix_metadata['feature_names'] if re.match(f'\\A{feature_prefix}_', f)]
+    logger.spam("Feature names in group = %s", feature_names_in_group)
+
+    return feature_group, feature_names_in_group
+
+
+def get_feature_needs_imputation_in_train(aggregation, feature_names):
+    """Returns features that need imputation from training data
+
+    Args:
+        aggregation (SpacetimeAggregation)
+        feature_names (list) A list of feature names
+    """
+    features_imputed_in_train = [
+        f for f in set(feature_names)
+        if not f.endswith('_imp')
+        and aggregation.imputation_flag_base(f) + '_imp' in feature_names
+    ]
+    logger.spam("Features imputed in train = %s", features_imputed_in_train)
+    return features_imputed_in_train
+
+
+def get_feature_needs_imputation_in_production(aggregation, db_engine):
+    """Returns features that need imputation from triage_production
+
+    Args:
+        aggregation (SpacetimeAggregation)
+        db_engine (sqlalchemy.db.engine)
+    """
+    with db_engine.begin() as conn:
+        nulls_results = conn.execute(aggregation.find_nulls())
+
+    null_counts = nulls_results.first().items()
+    features_imputed_in_production = [col for (col, val) in null_counts if val is not None and val > 0]
+
+    return features_imputed_in_production
+
+
+def get_retrain_config_from_model_id(db_engine, model_id):
+    query = """
+        SELECT re.config FROM triage_metadata.models m
+        LEFT JOIN triage_metadata.triage_runs r
+        ON m.built_in_triage_run = r.id
+        LEFT JOIN triage_metadata.retrain re
+        ON (re.retrain_hash = r.run_hash and r.run_type='retrain')
+        WHERE m.model_id = %s;
+    """
+
+    (config,) = db_engine.execute(query, model_id).first()
+    return config
+
+
+@db_retry
+def associate_models_with_retrain(retrain_hash, model_hashes, db_engine):
+    session = sessionmaker(bind=db_engine)()
+    for model_hash in model_hashes:
+        session.merge(RetrainModel(retrain_hash=retrain_hash, model_hash=model_hash))
+    session.commit()
+    session.close()
+    logger.spam("Associated models with retrain in database")
+
+
+@db_retry
+def save_retrain_and_get_hash(config, db_engine):
+    retrain_hash = filename_friendly_hash(config)
+    session = sessionmaker(bind=db_engine)()
+    session.merge(Retrain(retrain_hash=retrain_hash, config=config))
+    session.commit()
+    session.close()
+    return retrain_hash
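One property of `save_retrain_and_get_hash` worth calling out: because the `retrain` row's primary key is `filename_friendly_hash(config)`, re-running a retrain with an identical config merges into the same `triage_metadata.retrain` row instead of creating a duplicate. A small sketch of the determinism this relies on; the config values below are made up:

```python
from triage.component.catwalk.utils import filename_friendly_hash

# A hypothetical retrain config, mirroring the keys built in Retrainer.retrain
config = {
    "model_group_id": 30,
    "prediction_date": "2021-04-04",
    "test_label_timespan": "3month",
    "test_duration": "0day",
}

# The hash is a pure function of the config contents, so equal configs
# always map to the same retrain_hash (and thus the same database row)
assert filename_friendly_hash(config) == filename_friendly_hash(dict(config))
```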
diff --git a/src/triage/tracking.py b/src/triage/tracking.py
index 0803d6b5e..380c1507b 100644
--- a/src/triage/tracking.py
+++ b/src/triage/tracking.py
@@ -21,7 +21,7 @@
     pip_freeze = None
 
 
-from triage.component.results_schema import ExperimentRun, ExperimentRunStatus
+from triage.component.results_schema import TriageRun, TriageRunStatus
 
 
 def infer_git_hash():
@@ -100,7 +100,7 @@ def initialize_tracking_and_get_run_id(
     experiment_kwargs,
     db_engine
 ):
-    """Create a row in the ExperimentRun table with some initial info and return the created run_id
+    """Create a row in the TriageRun table with some initial info and return the created run_id
 
     Args:
         experiment_hash (str) An experiment hash that exists in the experiments table
@@ -115,14 +115,15 @@ def initialize_tracking_and_get_run_id(
         k: (classpath(v) if isinstance(v, type) else v)
         for k, v in experiment_kwargs.items()
     }
-    run = ExperimentRun(
+    run = TriageRun(
         start_time=datetime.datetime.now(),
         git_hash=infer_git_hash(),
         triage_version=infer_triage_version(),
         python_version=infer_python_version(),
-        experiment_hash=experiment_hash,
+        run_type="experiment",
+        run_hash=experiment_hash,
         last_updated_time=datetime.datetime.now(),
-        current_status=ExperimentRunStatus.started,
+        current_status=TriageRunStatus.started,
         installed_libraries=infer_installed_libraries(),
         platform=platform.platform(),
         os_user=getpass.getuser(),
@@ -144,7 +145,7 @@
 
 def get_run_for_update(db_engine, run_id):
-    """Yields an ExperimentRun at the given run_id for update
+    """Yields a TriageRun at the given run_id for update
 
     Will kick the last_update_time timestamp of the row each time.
 
@@ -152,7 +153,7 @@
         db_engine (sqlalchemy.engine)
         run_id (int) The identifier/primary key of the run
     """
-    return get_for_update(db_engine, ExperimentRun, run_id)
+    return get_for_update(db_engine, TriageRun, run_id)
 
 
 def experiment_entrypoint(entrypoint_func):
@@ -161,8 +162,8 @@
     To update the database, it requires the instance of the wrapped method to
     have a db_engine and run_id.
 
-    Upon method entry, will update the ExperimentRun row with the wrapped method name.
-    Upon method exit, will update the ExperimentRun row with the status (either failed or completed)
+    Upon method entry, will update the TriageRun row with the wrapped method name.
+    Upon method exit, will update the TriageRun row with the status (either failed or completed)
     """
     @wraps(entrypoint_func)
     def with_entrypoint(self, *args, **kwargs):
@@ -174,12 +175,12 @@ def with_entrypoint(self, *args, **kwargs):
             return_value = entrypoint_func(self, *args, **kwargs)
         except Exception as exc:
             with get_run_for_update(self.db_engine, self.run_id) as run_obj:
-                run_obj.current_status = ExperimentRunStatus.failed
+                run_obj.current_status = TriageRunStatus.failed
                 run_obj.stacktrace = str(exc)
             raise exc
 
         with get_run_for_update(self.db_engine, self.run_id) as run_obj:
-            run_obj.current_status = ExperimentRunStatus.completed
+            run_obj.current_status = TriageRunStatus.completed
 
         return return_value
 
@@ -187,7 +188,7 @@
 
 def increment_field(field, run_id, db_engine):
-    """Increment an ExperimentRun's named field.
+    """Increment a TriageRun's named field.
 
     Expects that the field is an integer in the database.
 
@@ -201,8 +202,8 @@ def increment_field(field, run_id, db_engine):
     with scoped_session(db_engine) as session:
         # Use an update query instead of a session merge so it happens in one atomic query
         # and protect against race conditions
-        session.query(ExperimentRun).filter_by(run_id=run_id).update({
-            field: getattr(ExperimentRun, field) + 1,
+        session.query(TriageRun).filter_by(run_id=run_id).update({
+            field: getattr(TriageRun, field) + 1,
             'last_updated_time': datetime.datetime.now()
         })
 
@@ -230,7 +231,7 @@ def record_model_building_started(run_id, db_engine):
 
 def built_matrix(run_id, db_engine):
-    """Increment the matrix build counter for the ExperimentRun
+    """Increment the matrix build counter for the TriageRun
 
     Args:
         run_id (int) The identifier/primary key of the run
@@ -240,7 +241,7 @@
 
 def skipped_matrix(run_id, db_engine):
-    """Increment the matrix skip counter for the ExperimentRun
+    """Increment the matrix skip counter for the TriageRun
 
     Args:
         run_id (int) The identifier/primary key of the run
@@ -250,7 +251,7 @@
 
 def errored_matrix(run_id, db_engine):
-    """Increment the matrix error counter for the ExperimentRun
+    """Increment the matrix error counter for the TriageRun
 
     Args:
         run_id (int) The identifier/primary key of the run
@@ -260,7 +261,7 @@
 
 def built_model(run_id, db_engine):
-    """Increment the model build counter for the ExperimentRun
+    """Increment the model build counter for the TriageRun
 
     Args:
         run_id (int) The identifier/primary key of the run
@@ -270,7 +271,7 @@
 
 def skipped_model(run_id, db_engine):
-    """Increment the model skip counter for the ExperimentRun
+    """Increment the model skip counter for the TriageRun
 
     Args:
         run_id (int) The identifier/primary key of the run
@@ -280,7 +281,7 @@
 
 def errored_model(run_id, db_engine):
-    """Increment the model error counter for the ExperimentRun
+    """Increment the model error counter for the TriageRun
 
     Args:
         run_id (int) The identifier/primary key of the run
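These counter helpers all funnel into `increment_field`, which issues a single atomic UPDATE (`field = field + 1`) rather than a read-modify-write, so two concurrent workers cannot lose each other's increments. A usage sketch, assuming a reachable database with an existing run row; the connection string and run id are hypothetical:

```python
from triage import create_engine
from triage.tracking import built_matrix, errored_model

db_engine = create_engine("postgresql://user:pass@localhost/triage_db")  # hypothetical DSN
run_id = 1  # hypothetical TriageRun id

built_matrix(run_id, db_engine)   # bumps the run's matrix-build counter atomically
errored_model(run_id, db_engine)  # bumps the run's model-error counter atomically
```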