Skip to content

Commit

Permalink
smtk residuals code refactor fix tests WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
rizac committed Aug 4, 2023
1 parent 0a20b33 commit c4dcfcc
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
7 changes: 5 additions & 2 deletions egsim/smtk/residuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def get_residuals(gsims: Iterable[str], imts: Iterable[str],
return compacted(flatfile, gsims, imts)


def calculate_expected_motions(gsims: Iterable[GMPE], imts: list[str],
def calculate_expected_motions(gsims: Iterable[GMPE], imts: Iterable[str],
flatfile: pd.DataFrame) -> pd.DataFrame:
expected:pd.DataFrame = pd.DataFrame(index=flatfile.index)
for context in yield_event_contexts(flatfile):
Expand Down Expand Up @@ -452,9 +452,12 @@ def compacted(flatfile:pd.DataFrame,
col for col in ['magnitude', 'vs30', 'repi', 'rrup', 'rhypo',
'rjb', 'rx', 'event_depth'] if col in f_cols
]
computed_labels = c_labels.residuals_columns | c_labels.lh_columns | \
{c_labels.total, c_labels.inter_ev, c_labels.intra_ev,
c_labels.mean}
computed_columns = [
col for col, gsim, imtx, lbl in get_computed_columns(gsims, imts, flatfile)
if lbl in c_labels.residuals_columns or lbl in c_labels.lh_columns
if lbl in computed_labels
]
return flatfile[columns + observed_columns + computed_columns]

Expand Down
34 changes: 19 additions & 15 deletions tests/smtk/residuals/test_residuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from egsim.smtk import residuals
from egsim.smtk.flatfile import read_flatfile
from egsim.smtk import convert_accel_units

from egsim.smtk.residuals import column_label, c_labels

BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")

Expand Down Expand Up @@ -92,10 +92,23 @@ def test_residuals_execution(self):
with open(os.path.join(BASE_DATA_PATH, file)) as _:
exp_dict = json.load(_)
# check results:
self.assertEqual(len(exp_dict), len(res_dict))
for gsim in res_dict:
self.assertEqual(len(exp_dict[gsim]), len(res_dict[gsim]))
for imt in res_dict[gsim]:
# self.assertEqual(len(exp_dict), len(res_dict))
for gsim in exp_dict:
# self.assertEqual(len(exp_dict[gsim]), len(res_dict[gsim]))
for imt in exp_dict[gsim]:
# check values
values = res_dict[column_label(gsim, imt, c_labels.total_res)]
vals_ok = np.allclose(values, exp_dict[gsim][imt]["Total"])
self.assertTrue(vals_ok)
values = res_dict[column_label(gsim, imt, c_labels.intra_ev_res)]
vals_ok = np.allclose(values, exp_dict[gsim][imt]["Intra event"])
self.assertTrue(vals_ok)
values = res_dict[column_label(gsim, imt, c_labels.inter_ev_res)]
vals_ok = np.allclose(values, exp_dict[gsim][imt]["Inter event"])
self.assertTrue(vals_ok)


# check other stuff:
if gsim == "AkkarEtAlRjb2014":
# For Akkar et al - inter-event residuals should have
# 4 elements and the intra-event residuals 41
Expand All @@ -112,16 +125,7 @@ def test_residuals_execution(self):
len(res_dict[gsim][imt]["Intra event"]), self.num_records)
self.assertEqual(
len(res_dict[gsim][imt]["Total"]), self.num_records)
# check values
values = res_dict[gsim][imt]["Total"]
vals_ok = np.allclose(values, exp_dict[gsim][imt]["Total"])
self.assertTrue(vals_ok)
values = res_dict[gsim][imt]["Inter event"]
vals_ok = np.allclose(values, exp_dict[gsim][imt]["Inter event"])
self.assertTrue(vals_ok)
values = res_dict[gsim][imt]["Intra event"]
vals_ok = np.allclose(values, exp_dict[gsim][imt]["Intra event"])
self.assertTrue(vals_ok)


# residuals.get_residual_statistics()

Expand Down

0 comments on commit c4dcfcc

Please sign in to comment.