Skip to content

Commit

Permalink
Merge pull request #246 from CODEX-CELIDA/filter-intervention-criteri…
Browse files Browse the repository at this point in the history
…a-by-population

fix: filter intervention criteria by population
  • Loading branch information
glichtner authored Dec 10, 2024
2 parents a268ee8 + 5f7823a commit d853f0e
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 20 deletions.
34 changes: 32 additions & 2 deletions execution_engine/omop/cohort/population_intervention_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,44 @@ def url(self) -> str:
"""
return self._url

@classmethod
def filter_symbols(cls, node: logic.Expr, filter_: logic.Expr) -> logic.Expr:
"""
Filter (=AND-combine) all symbols by the applied filter function
Used to filter all intervention criteria (symbols) by the population output in order to exclude
all intervention events outside the population intervals, which may otherwise interfere with corrected
determination of temporal combination, i.e. the presence of an intervention event during some time window.
"""

if isinstance(node, logic.Symbol):
return logic.And(node, filter_, category=CohortCategory.INTERVENTION)

if hasattr(node, "args") and isinstance(node.args, tuple):
converted_args = [cls.filter_symbols(a, filter_) for a in node.args]

if any(a is not b for a, b in zip(node.args, converted_args)):
node.args = tuple(converted_args)

return node

def execution_graph(self) -> ExecutionGraph:
"""
Get the execution graph for the population/intervention pair.
"""

p = ExecutionGraph.combination_to_expression(self._population)
i = ExecutionGraph.combination_to_expression(self._intervention)

# filter all intervention criteria by the output of the population - this is performed to filter out
# intervention events that outside of the population intervals (i.e. the time windows during which
# patients are part of the population) as otherwise events outside of the population time may be picked up
# by Temporal criteria that determine the presence of some event or condition during a specific time window.
i = self.filter_symbols(i, filter_=p)

pi = logic.LeftDependentToggle(
ExecutionGraph.combination_to_expression(self._population),
ExecutionGraph.combination_to_expression(self._intervention),
p,
i,
category=CohortCategory.POPULATION_INTERVENTION,
)
pi_graph = ExecutionGraph.from_expression(pi, self._base_criterion)
Expand Down
9 changes: 0 additions & 9 deletions execution_engine/omop/cohort/recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,24 +116,19 @@ def execution_graph(self) -> ExecutionGraph:
"""

p_nodes = []
i_nodes = []
pi_nodes = []
pi_graphs = []

for pi_pair in self._pi_pairs:
pi_graph = pi_pair.execution_graph()

p_nodes.append(pi_graph.sink_node(CohortCategory.POPULATION))
i_nodes.append(pi_graph.sink_node(CohortCategory.INTERVENTION))
pi_nodes.append(pi_graph.sink_node(CohortCategory.POPULATION_INTERVENTION))
pi_graphs.append(pi_graph)

p_combination_node = logic.NoDataPreservingOr(
*p_nodes, category=CohortCategory.POPULATION
)
i_combination_node = logic.NoDataPreservingOr(
*i_nodes, category=CohortCategory.INTERVENTION
)
pi_combination_node = logic.NoDataPreservingAnd(
*pi_nodes, category=CohortCategory.POPULATION_INTERVENTION
)
Expand All @@ -143,9 +138,6 @@ def execution_graph(self) -> ExecutionGraph:
common_graph.add_node(
p_combination_node, store_result=True, category=CohortCategory.POPULATION
)
common_graph.add_node(
i_combination_node, store_result=True, category=CohortCategory.INTERVENTION
)

common_graph.add_node(
pi_combination_node,
Expand All @@ -154,7 +146,6 @@ def execution_graph(self) -> ExecutionGraph:
)

common_graph.add_edges_from((src, p_combination_node) for src in p_nodes)
common_graph.add_edges_from((src, i_combination_node) for src in i_nodes)
common_graph.add_edges_from((src, pi_combination_node) for src in pi_nodes)

return common_graph
Expand Down
15 changes: 13 additions & 2 deletions tests/recommendation/test_recommendation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,14 @@ def assemble_daily_recommendation_evaluation(
df, group["intervention"]
)

# filter intervention by population
if df[f"i_{group_name}"].dtype == bool:
df[f"i_{group_name}"] &= df[f"p_{group_name}"]
else:
df[f"i_{group_name}"] = (
df[f"p_{group_name}"] == IntervalType.POSITIVE
) & (df[f"i_{group_name}"] == IntervalType.POSITIVE)

# expressions like "Eq(a+b+c, 1)" (at least one criterion) yield boolean columns and must
# be converted to IntervalType
if df[(f"p_{group_name}", "")].dtype == bool:
Expand Down Expand Up @@ -984,7 +992,11 @@ def process_result(df_result):
df_result_p_i = omopdb.query(
get_query(
partial_day_coverage,
category=[CohortCategory.BASE, CohortCategory.POPULATION],
category=[
CohortCategory.BASE,
CohortCategory.POPULATION,
CohortCategory.INTERVENTION,
],
)
)

Expand All @@ -993,7 +1005,6 @@ def process_result(df_result):
get_query(
full_day_coverage,
category=[
CohortCategory.INTERVENTION,
CohortCategory.POPULATION_INTERVENTION,
],
)
Expand Down
8 changes: 8 additions & 0 deletions tests/recommendation/test_recommendation_base_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,14 @@ def assemble_daily_recommendation_evaluation(
df[f"p_{group_name}"] = evaluate_expression(group["population"], df)
df[f"i_{group_name}"] = evaluate_expression(group["intervention"], df)

# filter intervention by population
if df[f"i_{group_name}"].dtype == bool:
df[f"i_{group_name}"] &= df[f"p_{group_name}"]
else:
df[f"i_{group_name}"] = (
df[f"p_{group_name}"] == IntervalType.POSITIVE
) & (df[f"i_{group_name}"] == IntervalType.POSITIVE)

# expressions like "Eq(a+b+c, 1)" (at least one criterion) yield boolean columns and must
# be converted to IntervalType
if df[f"p_{group_name}"].dtype == bool:
Expand Down
12 changes: 5 additions & 7 deletions tests/recommendation/utils/result_comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,14 @@ def plan_names(self) -> list[str]:
return [col[2:] for col in self.df.columns if col.startswith("i_")]

def plan_name_column_names(self) -> list[str]:
cols = [
"_".join(i) for i in itertools.product(["p", "i", "p_i"], self.plan_names)
]
return cols + ["p", "i", "p_i"]
cols = ["_".join(i) for i in itertools.product(["p", "p_i"], self.plan_names)]
return cols + ["p", "p_i"]

def derive_database_result(self, df: pd.DataFrame) -> "ResultComparator":
df = df.copy()
df.loc[
:, [c for c in self.plan_name_column_names() if c not in df.columns]
] = False
df.loc[:, [c for c in self.plan_name_column_names() if c not in df.columns]] = (
False
)

return ResultComparator(name="db", df=df)

Expand Down

0 comments on commit d853f0e

Please sign in to comment.