diff --git a/tests/test_complex_pipeline.py b/tests/test_complex_pipeline.py index daf9333b..60685c42 100644 --- a/tests/test_complex_pipeline.py +++ b/tests/test_complex_pipeline.py @@ -6,7 +6,7 @@ from sqlalchemy.sql.sqltypes import Integer, String from datapipe.compute import Catalog, Pipeline, Table, build_compute, run_steps -from datapipe.datatable import DataStore +from datapipe.datatable import DataStore, DataTable from datapipe.step.batch_generate import BatchGenerate from datapipe.step.batch_transform import BatchTransform from datapipe.store.database import TableStoreDB @@ -429,3 +429,112 @@ def test_complex_transform_with_many_recordings_N1000(dbconn): @pytest.mark.skip(reason="fails on sqlite") def test_complex_transform_with_many_recordings_N10000(dbconn): complex_transform_with_many_recordings(dbconn, N=10000) + + +def test_applying_prediction_on_best_model_only(dbconn) -> None: + # N = 100 + N = 5 + ds = DataStore(dbconn, create_meta_table=True) + + catalog = Catalog( + { + "tbl_image": Table( + store=TableStoreDB( + dbconn, + "tbl_image", + [ + Column("image_id", Integer, primary_key=True), + ], + True, + ) + ), + "tbl_model": Table( + store=TableStoreDB( + dbconn, + "tbl_model", + [ + Column("model_id", Integer, primary_key=True), + ], + True, + ) + ), + "tbl_best_model": Table( + store=TableStoreDB( + dbconn, + "tbl_best_model", + [ + Column("model_id", Integer, primary_key=True), + ], + True, + ) + ), + "tbl_prediction": Table( + store=TableStoreDB( + dbconn, + "tbl_prediction", + [ + Column("image_id", Integer, primary_key=True), + Column("model_id", Integer, primary_key=True), + ], + True, + ) + ), + } + ) + + test_df__image = pd.DataFrame({"image_id": range(N)}) + test_df__model = pd.DataFrame({"model_id": [0, 1, 2, 3, 4]}) + test_df__best_model = pd.DataFrame({"model_id": [4]}) + + def inference_only_on_best_model( + df__image: pd.DataFrame, + df__model: pd.DataFrame, + df__best_model: pd.DataFrame, + idx: IndexDF, + ): + df__prediction = pd.merge(df__image, df__model, how="cross") + return df__prediction[["image_id", "model_id"]] + + pipeline = Pipeline( + [ + BatchTransform( + func=inference_only_on_best_model, + inputs=[ + "tbl_image", # image_id + "tbl_model", # model_id + Required("tbl_best_model"), # model_id + ], + outputs=["tbl_prediction"], + transform_keys=["image_id", "model_id"], + ), + ] + ) + + steps = build_compute(ds, catalog, pipeline) + + ds.get_table("tbl_image").store_chunk(test_df__image) + ds.get_table("tbl_model").store_chunk(test_df__model) + ds.get_table("tbl_best_model").store_chunk(test_df__best_model) + + run_steps(ds, steps) + + test__df_prediction = pd.DataFrame({"image_id": range(N), "model_id": [4] * N}) + assert_df_equal( + ds.get_table("tbl_prediction").get_data(), + test__df_prediction, + index_cols=["image_id", "model_id"], + ) + + test_df__new_best_model = pd.DataFrame({"model_id": [3]}) + dt__tbl_best_model: DataTable = ds.get_table("tbl_best_model") + dt__tbl_best_model.delete_by_idx(cast(IndexDF, dt__tbl_best_model.get_data())) + dt__tbl_best_model.store_chunk(test_df__new_best_model) + + run_steps(ds, steps) + + test__new_df_prediction = pd.DataFrame({"image_id": range(N), "model_id": [3] * N}) + assert_df_equal( + ds.get_table("tbl_prediction").get_data(), + test__new_df_prediction, + index_cols=["image_id", "model_id"], + )