From b254b2b4c1f75e0adb828f149162232d7595b4f0 Mon Sep 17 00:00:00 2001 From: Zebedee Nicholls Date: Tue, 12 Oct 2021 11:58:21 +1100 Subject: [PATCH] Format --- .../database_crunchers/rms_closest.py | 34 +++++++++++++------ .../crunchers/test_cruncher_rms_closest.py | 34 +++++++++++++++---- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/src/silicone/database_crunchers/rms_closest.py b/src/silicone/database_crunchers/rms_closest.py index 66d11fe3..fcc2f182 100644 --- a/src/silicone/database_crunchers/rms_closest.py +++ b/src/silicone/database_crunchers/rms_closest.py @@ -199,7 +199,9 @@ def filler(in_iamdf): return filler - def infill_multiple(self, to_infill, variable_followers, variable_leaders, weighting=None): + def infill_multiple( + self, to_infill, variable_followers, variable_leaders, weighting=None + ): """ Infill multiple variables simultaneously @@ -244,13 +246,11 @@ def infill_multiple(self, to_infill, variable_followers, variable_leaders, weigh db_lead_ts = db_lead.timeseries() db_lead_ts.index = db_lead_ts.index.rename( - ["model_db", "scenario_db"], - level=["model", "scenario"] + ["model_db", "scenario_db"], level=["model", "scenario"] ) to_infill_lead_ts = to_infill_lead.timeseries() to_infill_lead_ts.index = to_infill_lead_ts.index.rename( - ["model_lead", "scenario_lead"], - level=["model", "scenario"] + ["model_lead", "scenario_lead"], level=["model", "scenario"] ) common_cols = _get_common_cols(db_lead_ts, to_infill_lead_ts) @@ -262,24 +262,36 @@ def infill_multiple(self, to_infill, variable_followers, variable_leaders, weigh rms = ((db_lead_ts - to_infill_lead_ts) ** 2).mean(axis=1) ** 0.5 if weighting is not None: raise NotImplementedError("weighting other than None") - rms = rms.groupby(["model_lead", "scenario_lead", "model_db", "scenario_db"]).sum() + rms = rms.groupby( + ["model_lead", "scenario_lead", "model_db", "scenario_db"] + ).sum() db_timeseries = self._db.filter(variable=variable_followers).timeseries() db_timeseries = db_timeseries[common_cols] out = [] - for (model, scenario), rms_mod_scen in rms.groupby(["model_lead", "scenario_lead"]): + for (model, scenario), rms_mod_scen in rms.groupby( + ["model_lead", "scenario_lead"] + ): variable_followers_h = set(variable_followers) - for (model_db, scenario_db), _ in rms_mod_scen[(model, scenario)].sort_values().iteritems(): + for (model_db, scenario_db), _ in ( + rms_mod_scen[(model, scenario)].sort_values().iteritems() + ): infill_timeseries = db_timeseries.loc[ (db_timeseries.index.get_level_values("model") == model_db) & (db_timeseries.index.get_level_values("scenario") == scenario_db) - & (db_timeseries.index.get_level_values("variable").isin(variable_followers_h)), - : + & ( + db_timeseries.index.get_level_values("variable").isin( + variable_followers_h + ) + ), + :, ].copy() - variable_followers_h = variable_followers_h - set(infill_timeseries.index.get_level_values("variable")) + variable_followers_h = variable_followers_h - set( + infill_timeseries.index.get_level_values("variable") + ) infill_timeseries = infill_timeseries.reset_index() infill_timeseries.loc[:, "model"] = model diff --git a/tests/integration/crunchers/test_cruncher_rms_closest.py b/tests/integration/crunchers/test_cruncher_rms_closest.py index 4fb155f9..41b6318f 100644 --- a/tests/integration/crunchers/test_cruncher_rms_closest.py +++ b/tests/integration/crunchers/test_cruncher_rms_closest.py @@ -593,17 +593,39 @@ def test_arbitrary_follower(self): _msb + ["World", "Emissions|CH4", "Mt CH4/yr"] + list(model_b_ch4_emms), _msa + ["World", "Emissions|N2O", "Mt N2O/yr"] + list(model_a_n2o_emms), ], - columns=["model", "scenario", "region", "variable", "unit", 2010, 2015, 2020, 2030], + columns=[ + "model", + "scenario", + "region", + "variable", + "unit", + 2010, + 2015, + 2020, + 2030, + ], ) database = IamDataFrame(database) # 2 scenarios to infill to_infill = pd.DataFrame( [ - ["model_c", "scen_a", "World", "Emissions|CO2", "Gt C/yr"] + list(model_a_co2_emms + 0.1), - ["model_d", "scen_a", "World", "Emissions|CO2", "Gt C/yr"] + list(model_b_co2_emms + 0.1), + ["model_c", "scen_a", "World", "Emissions|CO2", "Gt C/yr"] + + list(model_a_co2_emms + 0.1), + ["model_d", "scen_a", "World", "Emissions|CO2", "Gt C/yr"] + + list(model_b_co2_emms + 0.1), + ], + columns=[ + "model", + "scenario", + "region", + "variable", + "unit", + 2010, + 2015, + 2020, + 2030, ], - columns=["model", "scenario", "region", "variable", "unit", 2010, 2015, 2020, 2030], ) to_infill = IamDataFrame(to_infill) @@ -617,9 +639,7 @@ def test_arbitrary_follower(self): one_by_one = concat(one_by_one) - res = cruncher.infill_multiple( - to_infill, to_infill_variables, [lead] - ) + res = cruncher.infill_multiple(to_infill, to_infill_variables, [lead]) assert res.equals(one_by_one)