Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
znicholls committed Oct 12, 2021
1 parent c4b8b68 commit b254b2b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 18 deletions.
34 changes: 23 additions & 11 deletions src/silicone/database_crunchers/rms_closest.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ def filler(in_iamdf):

return filler

def infill_multiple(self, to_infill, variable_followers, variable_leaders, weighting=None):
def infill_multiple(
self, to_infill, variable_followers, variable_leaders, weighting=None
):
"""
Infill multiple variables simultaneously
Expand Down Expand Up @@ -244,13 +246,11 @@ def infill_multiple(self, to_infill, variable_followers, variable_leaders, weigh

db_lead_ts = db_lead.timeseries()
db_lead_ts.index = db_lead_ts.index.rename(
["model_db", "scenario_db"],
level=["model", "scenario"]
["model_db", "scenario_db"], level=["model", "scenario"]
)
to_infill_lead_ts = to_infill_lead.timeseries()
to_infill_lead_ts.index = to_infill_lead_ts.index.rename(
["model_lead", "scenario_lead"],
level=["model", "scenario"]
["model_lead", "scenario_lead"], level=["model", "scenario"]
)

common_cols = _get_common_cols(db_lead_ts, to_infill_lead_ts)
Expand All @@ -262,24 +262,36 @@ def infill_multiple(self, to_infill, variable_followers, variable_leaders, weigh
rms = ((db_lead_ts - to_infill_lead_ts) ** 2).mean(axis=1) ** 0.5
if weighting is not None:
raise NotImplementedError("weighting other than None")
rms = rms.groupby(["model_lead", "scenario_lead", "model_db", "scenario_db"]).sum()
rms = rms.groupby(
["model_lead", "scenario_lead", "model_db", "scenario_db"]
).sum()

db_timeseries = self._db.filter(variable=variable_followers).timeseries()
db_timeseries = db_timeseries[common_cols]

out = []
for (model, scenario), rms_mod_scen in rms.groupby(["model_lead", "scenario_lead"]):
for (model, scenario), rms_mod_scen in rms.groupby(
["model_lead", "scenario_lead"]
):
variable_followers_h = set(variable_followers)

for (model_db, scenario_db), _ in rms_mod_scen[(model, scenario)].sort_values().iteritems():
for (model_db, scenario_db), _ in (
rms_mod_scen[(model, scenario)].sort_values().iteritems()
):
infill_timeseries = db_timeseries.loc[
(db_timeseries.index.get_level_values("model") == model_db)
& (db_timeseries.index.get_level_values("scenario") == scenario_db)
& (db_timeseries.index.get_level_values("variable").isin(variable_followers_h)),
:
& (
db_timeseries.index.get_level_values("variable").isin(
variable_followers_h
)
),
:,
].copy()

variable_followers_h = variable_followers_h - set(infill_timeseries.index.get_level_values("variable"))
variable_followers_h = variable_followers_h - set(
infill_timeseries.index.get_level_values("variable")
)

infill_timeseries = infill_timeseries.reset_index()
infill_timeseries.loc[:, "model"] = model
Expand Down
34 changes: 27 additions & 7 deletions tests/integration/crunchers/test_cruncher_rms_closest.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,17 +593,39 @@ def test_arbitrary_follower(self):
_msb + ["World", "Emissions|CH4", "Mt CH4/yr"] + list(model_b_ch4_emms),
_msa + ["World", "Emissions|N2O", "Mt N2O/yr"] + list(model_a_n2o_emms),
],
columns=["model", "scenario", "region", "variable", "unit", 2010, 2015, 2020, 2030],
columns=[
"model",
"scenario",
"region",
"variable",
"unit",
2010,
2015,
2020,
2030,
],
)
database = IamDataFrame(database)

# 2 scenarios to infill
to_infill = pd.DataFrame(
[
["model_c", "scen_a", "World", "Emissions|CO2", "Gt C/yr"] + list(model_a_co2_emms + 0.1),
["model_d", "scen_a", "World", "Emissions|CO2", "Gt C/yr"] + list(model_b_co2_emms + 0.1),
["model_c", "scen_a", "World", "Emissions|CO2", "Gt C/yr"]
+ list(model_a_co2_emms + 0.1),
["model_d", "scen_a", "World", "Emissions|CO2", "Gt C/yr"]
+ list(model_b_co2_emms + 0.1),
],
columns=[
"model",
"scenario",
"region",
"variable",
"unit",
2010,
2015,
2020,
2030,
],
columns=["model", "scenario", "region", "variable", "unit", 2010, 2015, 2020, 2030],
)
to_infill = IamDataFrame(to_infill)

Expand All @@ -617,9 +639,7 @@ def test_arbitrary_follower(self):

one_by_one = concat(one_by_one)

res = cruncher.infill_multiple(
to_infill, to_infill_variables, [lead]
)
res = cruncher.infill_multiple(to_infill, to_infill_variables, [lead])

assert res.equals(one_by_one)

Expand Down

0 comments on commit b254b2b

Please sign in to comment.