diff --git a/adtl/__init__.py b/adtl/__init__.py index ded6a7f..4b922d2 100644 --- a/adtl/__init__.py +++ b/adtl/__init__.py @@ -269,7 +269,7 @@ def get_combined_type(row: StrDict, rule: StrDict, ctx: Context = None): return next( filter( lambda item: item is not None, - [get_value(row, r, ctx) for r in rules], + flatten([get_value(row, r, ctx) for r in rules]), ) ) except StopIteration: @@ -283,11 +283,11 @@ def get_combined_type(row: StrDict, rule: StrDict, ctx: Context = None): "excludeWhen rule should be 'none', 'false-like', or a list of values" ) - values = [get_value(row, r, ctx) for r in rules] + values = flatten([get_value(row, r, ctx) for r in rules]) if combined_type == "set": values = [*set(values)] if excludeWhen is None: - return values + return list(values) if excludeWhen == "none": return [v for v in values if v is not None] elif excludeWhen == "false-like": @@ -298,6 +298,19 @@ def get_combined_type(row: StrDict, rule: StrDict, ctx: Context = None): raise ValueError(f"Unknown {combined_type} in {rule}") +def flatten(xs): + """ + Flatten a list of lists +-/ non-list items + e.g. 
+ [None, ['Dexamethasone']] -> [None, 'Dexamethasone'] + """ + for x in xs: + if isinstance(x, Iterable) and not isinstance(x, (str, bytes)): + yield from flatten(x) + else: + yield x + + def expand_refs(spec_fragment: StrDict, defs: StrDict) -> Union[StrDict, List[StrDict]]: "Expand all references (ref) with definitions (defs)" diff --git a/tests/test_parser.py b/tests/test_parser.py index da55049..1333aad 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1232,3 +1232,60 @@ def test_main_save_report(): "validation_errors": {}, } Path("epoch-report.json").unlink() + + +@pytest.mark.parametrize( + "test_input,expected", + [ + ( + [None, ["Dexamethasone", "Fluticasone", "Methylprednisolone"]], + [None, "Dexamethasone", "Fluticasone", "Methylprednisolone"], + ), + ([12, ["13", "14"], [[15], ["sixteen"]]], [12, "13", "14", 15, "sixteen"]), + ], +) +def test_flatten(test_input, expected): + assert list(parser.flatten(test_input)) == expected + + +@pytest.mark.parametrize( + "test_row, test_combination, expected", + [ + ( + {"corticost": "", "corticost_v2": "Dexa"}, + "set", + [None, "Dexamethasone"], + ), + ({"corticost": "Decadron", "corticost_v2": "Dexa"}, "set", ["Dexamethasone"]), + ( + {"corticost": "", "corticost_v2": "Cortisonal"}, + "firstNonNull", + "Cortisonal", + ), + ], +) +def test_combinedtype_wordsubstituteset(test_row, test_combination, expected): + test_rule = { + "combinedType": test_combination, + "fields": [ + { + "field": "corticost", + "apply": { + "function": "wordSubstituteSet", + "params": [ + ["Metil?corten", "Prednisone"], + ["Decadron", "Dexamethasone"], + ], + }, + }, + { + "field": "corticost_v2", + "apply": { + "function": "wordSubstituteSet", + "params": [["Cortisonal", "Cortisonal"], ["Dexa", "Dexamethasone"]], + }, + }, + ], + } + + assert parser.get_combined_type(test_row, test_rule) == expected