From 6a51c0a11ff67abad0e879f18c3d1e03d91ee174 Mon Sep 17 00:00:00 2001
From: Tom Jemmett
Date: Wed, 1 Oct 2025 12:58:01 +0100
Subject: [PATCH 1/2] adds script to generate synthetic data

Largely follows what is in the nhp_model repo's notebook, but adds this to
the data extraction pipeline. Rather than extracting to the dev folder, it
extracts to synth.

---
 .../nhp_data-extract_nhp_for_containers.yaml  |  19 +-
 pyproject.toml                                |   1 +
 .../model_data/generate_synthetic_data.py     | 204 ++++++++++++++++++
 3 files changed, 223 insertions(+), 1 deletion(-)
 create mode 100644 src/nhp/data/model_data/generate_synthetic_data.py

diff --git a/databricks_workflows/nhp_data-extract_nhp_for_containers.yaml b/databricks_workflows/nhp_data-extract_nhp_for_containers.yaml
index 4b644e8..3e0ce49 100644
--- a/databricks_workflows/nhp_data-extract_nhp_for_containers.yaml
+++ b/databricks_workflows/nhp_data-extract_nhp_for_containers.yaml
@@ -143,9 +143,26 @@ resources:
            - pypi:
                package: pygam==0.9.1
            - whl: ../dist/*.whl
-      - task_key: clean_up
+      - task_key: generate_synthetic_data
        depends_on:
          - task_key: generate_national_gams
+        for_each_task:
+          inputs: "{{job.parameters.years}}"
+          task:
+            task_key: run_generate_synthetic_data
+            python_wheel_task:
+              package_name: nhp_data
+              entry_point: model_data-generate_synthetic_data
+              parameters:
+                - "{{job.parameters.save_path}}"
+                - "{{input}}"
+                - "20251001"
+            job_cluster_key: run_nhp_extracts_cluster
+            libraries:
+              - whl: ../dist/*.whl
+      - task_key: clean_up
+        depends_on:
+          - task_key: generate_synthetic_data
        python_wheel_task:
          package_name: nhp_data
          entry_point: model_data-clean_up

diff --git a/pyproject.toml b/pyproject.toml
index 96139ff..5719105 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,6 +70,7 @@ model_data-inequalities = "nhp.data.model_data.inequalities:main"
 model_data-ip = "nhp.data.model_data.ip:main"
 model_data-op = "nhp.data.model_data.op:main"
 model_data-clean_up = "nhp.data.model_data.clean_up:main"
+model_data-generate_synthetic_data = "nhp.data.model_data.generate_synthetic_data:main"
 
 reference-day_procedures = "nhp.data.reference.day_procedures:main"
 reference-ods_trusts = "nhp.data.reference.ods_trusts:main"

diff --git a/src/nhp/data/model_data/generate_synthetic_data.py b/src/nhp/data/model_data/generate_synthetic_data.py
new file mode 100644
index 0000000..966e9b2
--- /dev/null
+++ b/src/nhp/data/model_data/generate_synthetic_data.py
@@ -0,0 +1,204 @@
+import logging
+import os
+import sys
+import uuid
+from typing import Callable
+
+import numpy as np
+import pandas as pd
+import pyspark.sql.functions as F
+from pyspark.sql import DataFrame, SparkSession
+
+from nhp.data.get_spark import get_spark
+
+logger = logging.getLogger(__name__)
+
+
+def generate_data(name: str):
+    def decorator(func: Callable[["SynthData", DataFrame], pd.DataFrame]):
+        def wrapper(self):
+            logger.info(f"Generating synthetic data for {name}")
+            df = self.read_dev_file(name)
+            result = func(self, df)
+            self.save_synth_file(name, result)
+            logger.info(f"Synthetic data for {name} saved")
+
+        return wrapper
+
+    return decorator
+
+
+class SynthData:
+    # how many inpatient rows should we target?
+    IP_N = 100000
+
+    def __init__(self, fyear: int, path: str, seed: int, spark: SparkSession):
+        self._fyear = fyear
+        self._dev_path = f"{path}/dev"
+        self._synth_path = f"{path}/synth"
+        self._seed = seed
+
+        self._spark = spark
+
+    # helper methods
+
+    def read_dev_file(self, file: str) -> DataFrame:
+        return self.read_file(file, self._dev_path)
+
+    def read_synth_file(self, file: str) -> DataFrame:
+        return self.read_file(file, self._synth_path)
+
+    def read_file(self, file: str, path: str) -> DataFrame:
+        return (
+            self._spark.read.parquet(f"{path}/{file}")
+            .filter(F.col("fyear") == self._fyear)
+            .drop("fyear")
+        )
+
+    def save_synth_file(self, file: str, df: pd.DataFrame) -> None:
+        p = f"{self._synth_path}/{file}/fyear={self._fyear}/dataset=synthetic"
+        os.makedirs(p, exist_ok=True)
+        df.to_parquet(f"{p}/0.parquet")
+
+    def generate(self) -> None:
+        self._ip()
+        self._ip_activity_avoidance_strategies()
+        self._ip_efficiencies_strategies()
+        self._inequalities()
+        self._aae()
+        self._op()
+        self._birth_factors()
+        self._demographic_factors()
+        self._hsa_activity_tables()
+
+    # synth methods
+
+    @generate_data("ip")
+    def _ip(self, df: DataFrame) -> pd.DataFrame:
+        ip_r = min(self.IP_N / df.count(), 1.0)
+
+        df = df.sample(False, ip_r, self._seed)
+
+        ip = df.drop("dataset", "fyear").toPandas()
+        ip = ip.assign(sitetret=np.random.default_rng(self._seed).choice(["a", "b", "c"], len(ip)))
+
+        hrgs = list(ip["sushrg_trimmed"].value_counts()[:2].index).copy()
+        ip["sushrg_trimmed"] = ip["sushrg_trimmed"].replace(hrgs, ["HRG1", "HRG2"])
+
+        return ip
+
+    @generate_data("ip_activity_avoidance_strategies")
+    def _ip_activity_avoidance_strategies(self, df: DataFrame) -> pd.DataFrame:
+        ip_df = self.read_synth_file("ip")
+        return df.join(ip_df, "rn", "semi").toPandas()
+
+    @generate_data("ip_efficiencies_strategies")
+    def _ip_efficiencies_strategies(self, df: DataFrame) -> pd.DataFrame:
+        ip_df = self.read_synth_file("ip")
+        return df.join(ip_df, "rn", "semi").toPandas()
+
+    @generate_data("inequalities")
+    def _inequalities(self, df: DataFrame) -> pd.DataFrame:
+        inequalities = df.drop("dataset").toPandas()
+        # TODO: sort hrgs
+        hrgs = []
+        inequalities["sushrg_trimmed"] = inequalities["sushrg_trimmed"].replace(
+            hrgs, ["HRG1", "HRG2"]
+        )
+        return inequalities.drop_duplicates(
+            subset=["sushrg_trimmed", "icb", "imd_quintile"]
+        )
+
+    @generate_data("aae")
+    def _aae(self, df: DataFrame) -> pd.DataFrame:
+        rng = np.random.default_rng(self._seed)
+        n_aae_datasets = df.select("dataset").distinct().count()
+
+        df = df.drop("index", "dataset").withColumn("sitetret", F.lit("a"))
+
+        aae = (
+            df.groupBy(df.drop("arrivals").columns)
+            .agg(F.sum("arrivals").alias("arrivals"))
+            .toPandas()
+            .assign(arrivals=lambda r: rng.poisson(r["arrivals"] / n_aae_datasets))
+            .query("arrivals > 0")
+        )
+
+        aae["rn"] = [str(uuid.uuid4()) for _ in aae.index]
+
+        return aae
+
+    @generate_data("op")
+    def _op(self, df: DataFrame) -> pd.DataFrame:
+        rng = np.random.default_rng(self._seed)
+        n_op_datasets = df.select("dataset").distinct().count()
+
+        df = df.drop("index", "dataset").withColumn("sitetret", F.lit("a"))
+
+        op = (
+            df.groupBy(df.drop("attendances", "tele_attendances").columns)
+            .agg(
+                F.sum("attendances").alias("attendances"),
+                F.sum("tele_attendances").alias("tele_attendances"),
+            )
+            .toPandas()
+            .assign(
+                attendances=lambda r: rng.poisson(r["attendances"] / n_op_datasets),
+                tele_attendances=lambda r: rng.poisson(
+                    r["tele_attendances"] / n_op_datasets
+                ),
+            )
+            .query("(attendances > 0) or (tele_attendances > 0)")
+        )
+
+        op["rn"] = [str(uuid.uuid4()) for _ in op.index]
+
+        return op
+
+    @generate_data("birth_factors")
+    def _birth_factors(self, df: DataFrame) -> pd.DataFrame:
+        return (
+            df.drop("dataset")
+            .filter(~F.col("variant").startswith("custom_projection_"))
+            .toPandas()
+            .groupby(["variant", "sex", "age"], as_index=False)
+            .mean()
+        )
+
+    @generate_data("demographic_factors")
+    def _demographic_factors(self, df: DataFrame) -> pd.DataFrame:
+        return (
+            df.drop("dataset")
+            .filter(~F.col("variant").startswith("custom_projection_"))
+            .toPandas()
+            .groupby(["variant", "sex", "age"], as_index=False)
+            .mean()
+        )
+
+    @generate_data("hsa_activity_tables")
+    def _hsa_activity_tables(self, df: DataFrame) -> pd.DataFrame:
+        return (
+            df.drop("dataset")
+            .toPandas()
+            .groupby(["hsagrp", "sex", "age"], as_index=False)
+            .mean()
+        )
+
+
+def main():
+    logging.basicConfig(level=logging.INFO)
+    logger.setLevel(logging.INFO)
+
+    handler = logging.StreamHandler()
+    handler.setLevel(logging.INFO)
+
+    logging.getLogger("py4j").setLevel(logging.ERROR)
+
+    path = sys.argv[1]
+    fyear = int(sys.argv[2][:4])
+    seed = int(sys.argv[3])
+
+    spark = get_spark("model_data")
+
+    d = SynthData(fyear, path, seed, spark)
+    d.generate()
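
[Editorial aside, not part of the patch series: the generate_data decorator in
the first patch factors the shared read -> transform -> save cycle out of every
synth method, so each method only supplies its transform. A minimal,
pandas-only sketch of the same pattern — the TinySynth class and its stand-in
helpers are hypothetical, and no Spark is required:

    from typing import Callable

    import pandas as pd


    def generate_data(name: str):
        def decorator(func: Callable):
            def wrapper(self):
                df = self.read_dev_file(name)  # load the source extract
                result = func(self, df)  # decorated method supplies the transform
                self.save_synth_file(name, result)  # persist the synthetic output

            return wrapper

        return decorator


    class TinySynth:
        def __init__(self):
            self.saved = {}  # stands in for the synth folder

        def read_dev_file(self, name: str) -> pd.DataFrame:
            return pd.DataFrame({"x": [1, 2, 3]})  # stands in for spark.read.parquet

        def save_synth_file(self, name: str, df: pd.DataFrame) -> None:
            self.saved[name] = df  # stands in for df.to_parquet

        @generate_data("demo")
        def _demo(self, df: pd.DataFrame) -> pd.DataFrame:
            return df.assign(x=df["x"] * 2)


    t = TinySynth()
    t._demo()
    print(t.saved["demo"]["x"].tolist())  # [2, 4, 6]

Each decorated method therefore only expresses its transform; reading the dev
extract and saving the synthetic output happen once, in the wrapper.]
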
(tele_attendances > 0)") + ) + + op["rn"] = [str(uuid.uuid4()) for _ in op.index] + + return op + + @generate_data("birth_factors") + def _birth_factors(self, df: DataFrame) -> pd.DataFrame: + return ( + df.drop("dataset") + .filter(~F.col("variant").startswith("custom_projection_")) + .toPandas() + .groupby(["variant", "sex", "age"], as_index=False) + .mean() + ) + + @generate_data("demographic_factors") + def _demographic_factors(self, df: DataFrame) -> pd.DataFrame: + return ( + df.drop("dataset") + .filter(~F.col("variant").startswith("custom_projection_")) + .toPandas() + .groupby(["variant", "sex", "age"], as_index=False) + .mean() + ) + + @generate_data("hsa_activity_tables") + def _hsa_activity_tables(self, df: DataFrame) -> pd.DataFrame: + return ( + df.drop("dataset") + .toPandas() + .groupby(["hsagrp", "sex", "age"], as_index=False) + .mean() + ) + + +def main(): + logging.basicConfig(level=logging.INFO) + logger.setLevel(logging.INFO) + + handler = logging.StreamHandler() + handler.setLevel(logging.INFO) + + logging.getLogger("py4j").setLevel(logging.ERROR) + + path = sys.argv[1] + fyear = int(sys.argv[2][:4]) + seed = int(sys.argv[3]) + + spark = get_spark("model_data") + + d = SynthData(fyear, path, seed, spark) + d.generate() From fb03af1352f71cd38741095392501ebe35002986 Mon Sep 17 00:00:00 2001 From: Tom Jemmett Date: Thu, 2 Oct 2025 09:48:56 +0100 Subject: [PATCH 2/2] better handles hrg remapping --- .../model_data/generate_synthetic_data.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/src/nhp/data/model_data/generate_synthetic_data.py b/src/nhp/data/model_data/generate_synthetic_data.py index 966e9b2..a895ba5 100644 --- a/src/nhp/data/model_data/generate_synthetic_data.py +++ b/src/nhp/data/model_data/generate_synthetic_data.py @@ -71,6 +71,25 @@ def generate(self) -> None: self._demographic_factors() self._hsa_activity_tables() + @property + def hrgs(self) -> list: + if not hasattr(self, "_hrgs"): + ip_df = ( + self.read_dev_file("ip") + .groupBy("sushrg_trimmed") + .count() + .orderBy(F.desc("count")) + .collect() + ) + self._hrgs = [row["sushrg_trimmed"] for row in ip_df] + return self._hrgs + + def _hrg_remapping(self, col: pd.Series) -> pd.Series: + hrgs = self.hrgs + if not hrgs: + raise ValueError("HRGs list is empty. 
Cannot remap.") + return col.replace(hrgs, [f"HRG{i + 1}" for i, _ in enumerate(hrgs)]) + # synth methods @generate_data("ip") @@ -82,8 +101,7 @@ def _ip(self, df: DataFrame) -> pd.DataFrame: ip = df.drop("dataset", "fyear").toPandas() ip = ip.assign(sitetret=np.random.choice(["a", "b", "c"], len(ip))) - hrgs = list(ip["sushrg_trimmed"].value_counts()[:2].index).copy() - ip["sushrg_trimmed"] = ip["sushrg_trimmed"].replace(hrgs, ["HRG1", "HRG2"]) + ip["sushrg_trimmed"] = self._hrg_remapping(ip["sushrg_trimmed"]) return ip @@ -100,10 +118,8 @@ def _ip_efficiencies_strategies(self, df: DataFrame) -> pd.DataFrame: @generate_data("inequalities") def _inequalities(self, df: DataFrame) -> pd.DataFrame: inequalities = df.drop("dataset").toPandas() - # TODO: sort hrgs - hrgs = [] - inequalities["sushrg_trimmed"] = inequalities["sushrg_trimmed"].replace( - hrgs, ["HRG1", "HRG2"] + inequalities["sushrg_trimmed"] = self._hrg_remapping( + inequalities["sushrg_trimmed"] ) return inequalities.drop_duplicates( subset=["sushrg_trimmed", "icb", "imd_quintile"] @@ -150,6 +166,7 @@ def _op(self, df: DataFrame) -> pd.DataFrame: ) .query("(attendances > 0) or (tele_attendances > 0)") ) + op["sushrg_trimmed"] = self._hrg_remapping(op["sushrg_trimmed"]) op["rn"] = [str(uuid.uuid4()) for _ in op.index]