From 091207187b8b97a7061e58afa82e3def57530fab Mon Sep 17 00:00:00 2001 From: Stephen Privitera Date: Thu, 23 Jan 2025 13:45:07 +0100 Subject: [PATCH] black --- phenex/filters/__init__.py | 9 +- phenex/filters/categorical_filter.py | 10 +- phenex/ibis_connect.py | 135 ++++++++------ phenex/mappers.py | 159 ++++++++-------- phenex/phenotypes/codelist_phenotype.py | 2 +- phenex/phenotypes/functions.py | 4 +- phenex/tables.py | 176 +++++++++--------- phenex/test/phenotype_test_generator.py | 12 +- .../phenotypes/test_codelist_phenotype.py | 86 +++++---- 9 files changed, 313 insertions(+), 280 deletions(-) diff --git a/phenex/filters/__init__.py b/phenex/filters/__init__.py index 150d260..ee17f8e 100644 --- a/phenex/filters/__init__.py +++ b/phenex/filters/__init__.py @@ -1,2 +1,9 @@ from .relative_time_range_filter import RelativeTimeRangeFilter -from .value import GreaterThan, GreaterThanOrEqualTo, LessThan, LessThanOrEqualTo, EqualTo, Value \ No newline at end of file +from .value import ( + GreaterThan, + GreaterThanOrEqualTo, + LessThan, + LessThanOrEqualTo, + EqualTo, + Value, +) diff --git a/phenex/filters/categorical_filter.py b/phenex/filters/categorical_filter.py index 2595859..e65c519 100644 --- a/phenex/filters/categorical_filter.py +++ b/phenex/filters/categorical_filter.py @@ -30,13 +30,15 @@ def __init__( self.domain = domain super(CategoricalFilter, self).__init__() - def _filter(self, table: 'PhenexTable'): + def _filter(self, table: "PhenexTable"): return table.filter(table[self.column_name].isin(self.allowed_values)) - def autojoin_filter(self, table: 'PhenexTable', tables:dict = None): + def autojoin_filter(self, table: "PhenexTable", tables: dict = None): if self.column_name not in table.columns: if self.domain not in tables.keys(): - raise ValueError(f"Table required for categorical filter ({self.domain}) does not exist within domains dicitonary") - table = table.join(tables[self.domain], domains = tables) + raise ValueError( + f"Table required for categorical filter ({self.domain}) does not exist within domains dicitonary" + ) + table = table.join(tables[self.domain], domains=tables) # TODO downselect to original columns return table.filter(table[self.column_name].isin(self.allowed_values)) diff --git a/phenex/ibis_connect.py b/phenex/ibis_connect.py index 3b3f472..c6d2de7 100644 --- a/phenex/ibis_connect.py +++ b/phenex/ibis_connect.py @@ -40,28 +40,29 @@ class SnowflakeConnector: Methods: connect_dest() -> BaseBackend: Establishes and returns an Ibis backend connection to the destination Snowflake database and schema. - + connect_source() -> BaseBackend: Establishes and returns an Ibis backend connection to the source Snowflake database and schema. - + get_source_table(name_table: str) -> Table: Retrieves a table from the source Snowflake database. - + get_dest_table(name_table: str) -> Table: Retrieves a table from the destination Snowflake database. - + create_view(table: Table, name_table: Optional[str] = None, overwrite: bool = False) -> View: Create a view of a table in the destination Snowflake database. - + create_table(table: Table, name_table: Optional[str] = None, overwrite: bool = False) -> Table: Materialize a table in the destination Snowflake database. - + drop_table(name_table: str) -> None: Drop a table from the destination Snowflake database. - + drop_view(name_table: str) -> None: Drop a view from the destination Snowflake database. """ + def __init__( self, SNOWFLAKE_USER: Optional[str] = None, @@ -73,22 +74,36 @@ def __init__( SNOWFLAKE_DEST_DATABASE: Optional[str] = None, ): self.SNOWFLAKE_USER = SNOWFLAKE_USER or os.environ.get("SNOWFLAKE_USER") - self.SNOWFLAKE_ACCOUNT = SNOWFLAKE_ACCOUNT or os.environ.get("SNOWFLAKE_ACCOUNT") - self.SNOWFLAKE_WAREHOUSE = SNOWFLAKE_WAREHOUSE or os.environ.get("SNOWFLAKE_WAREHOUSE") + self.SNOWFLAKE_ACCOUNT = SNOWFLAKE_ACCOUNT or os.environ.get( + "SNOWFLAKE_ACCOUNT" + ) + self.SNOWFLAKE_WAREHOUSE = SNOWFLAKE_WAREHOUSE or os.environ.get( + "SNOWFLAKE_WAREHOUSE" + ) self.SNOWFLAKE_ROLE = SNOWFLAKE_ROLE or os.environ.get("SNOWFLAKE_ROLE") - self.SNOWFLAKE_PASSWORD = SNOWFLAKE_PASSWORD or os.environ.get("SNOWFLAKE_PASSWORD") - self.SNOWFLAKE_SOURCE_DATABASE = SNOWFLAKE_SOURCE_DATABASE or os.environ.get("SNOWFLAKE_SOURCE_DATABASE") - self.SNOWFLAKE_DEST_DATABASE = SNOWFLAKE_DEST_DATABASE or os.environ.get("SNOWFLAKE_DEST_DATABASE") + self.SNOWFLAKE_PASSWORD = SNOWFLAKE_PASSWORD or os.environ.get( + "SNOWFLAKE_PASSWORD" + ) + self.SNOWFLAKE_SOURCE_DATABASE = SNOWFLAKE_SOURCE_DATABASE or os.environ.get( + "SNOWFLAKE_SOURCE_DATABASE" + ) + self.SNOWFLAKE_DEST_DATABASE = SNOWFLAKE_DEST_DATABASE or os.environ.get( + "SNOWFLAKE_DEST_DATABASE" + ) try: - _, _ = self.SNOWFLAKE_SOURCE_DATABASE.split('.') + _, _ = self.SNOWFLAKE_SOURCE_DATABASE.split(".") except: - raise ValueError('Use a fully qualified database name (e.g. CATALOG.DATABASE).') + raise ValueError( + "Use a fully qualified database name (e.g. CATALOG.DATABASE)." + ) try: - _, _ = self.SNOWFLAKE_SOURCE_DATABASE.split('.') + _, _ = self.SNOWFLAKE_SOURCE_DATABASE.split(".") except: - raise ValueError('Use a fully qualified database name (e.g. CATALOG.DATABASE).') - + raise ValueError( + "Use a fully qualified database name (e.g. CATALOG.DATABASE)." + ) + required_vars = [ "SNOWFLAKE_USER", "SNOWFLAKE_ACCOUNT", @@ -102,22 +117,25 @@ def __init__( self.source_connection = self.connect_source() self.dest_connection = self.connect_dest() - def _check_env_vars(self, required_vars: List[str]): for var in required_vars: if not getattr(self, var): - raise ValueError(f"Missing required variable: {var}. Set in the environment or pass through __init__().") + raise ValueError( + f"Missing required variable: {var}. Set in the environment or pass through __init__()." + ) def _check_source_dest(self): - if (self.SNOWFLAKE_SOURCE_DATABASE == self.SNOWFLAKE_DEST_DATABASE and - self.SNOWFLAKE_SOURCE_SCHEMA == self.SNOWFLAKE_DEST_SCHEMA): + if ( + self.SNOWFLAKE_SOURCE_DATABASE == self.SNOWFLAKE_DEST_DATABASE + and self.SNOWFLAKE_SOURCE_SCHEMA == self.SNOWFLAKE_DEST_SCHEMA + ): raise ValueError("Source and destination locations cannot be the same.") def _connect(self, database) -> BaseBackend: - ''' + """ Private method to get a database connection. End users should use connect_source() and connect_dest() to get connections to source and destination databases. - ''' - database, schema = database.split('.') + """ + database, schema = database.split(".") # # In Ibis speak: catalog = collection of databases # database = collection of tables @@ -127,7 +145,7 @@ def _connect(self, database) -> BaseBackend: # # In the below connect method, the arguments are the SNOWFLAKE terms. # - + if self.SNOWFLAKE_PASSWORD: return ibis.snowflake.connect( user=self.SNOWFLAKE_USER, @@ -136,7 +154,7 @@ def _connect(self, database) -> BaseBackend: warehouse=self.SNOWFLAKE_WAREHOUSE, role=self.SNOWFLAKE_ROLE, database=database, - schema=schema + schema=schema, ) else: return ibis.snowflake.connect( @@ -146,9 +164,9 @@ def _connect(self, database) -> BaseBackend: warehouse=self.SNOWFLAKE_WAREHOUSE, role=self.SNOWFLAKE_ROLE, database=database, - schema=schema + schema=schema, ) - + def connect_dest(self) -> BaseBackend: """ Establishes and returns an Ibis backend connection to the destination Snowflake database. @@ -157,8 +175,8 @@ def connect_dest(self) -> BaseBackend: BaseBackend: Ibis backend connection to the destination Snowflake database. """ return self._connect( - database=self.SNOWFLAKE_DEST_DATABASE, - ) + database=self.SNOWFLAKE_DEST_DATABASE, + ) def connect_source(self) -> BaseBackend: """ @@ -168,8 +186,8 @@ def connect_source(self) -> BaseBackend: BaseBackend: Ibis backend connection to the source Snowflake database. """ return self._connect( - database=self.SNOWFLAKE_SOURCE_DATABASE, - ) + database=self.SNOWFLAKE_SOURCE_DATABASE, + ) def get_source_table(self, name_table): """ @@ -182,8 +200,7 @@ def get_source_table(self, name_table): Table: Ibis table object from the source Snowflake database. """ return self.dest_connection.table( - name_table, - database=self.SNOWFLAKE_SOURCE_DATABASE + name_table, database=self.SNOWFLAKE_SOURCE_DATABASE ) def get_dest_table(self, name_table): @@ -197,16 +214,16 @@ def get_dest_table(self, name_table): Table: Ibis table object from the destination Snowflake database. """ return self.dest_connection.table( - name_table, - database=self.SNOWFLAKE_DEST_DATABASE + name_table, database=self.SNOWFLAKE_DEST_DATABASE ) + def _get_output_table_name(self, table): if table.has_name: - name_table = table.get_name().split('.')[-1] + name_table = table.get_name().split(".")[-1] else: - raise ValueError('Must specify name_table!') + raise ValueError("Must specify name_table!") return name_table - + def create_view(self, table, name_table=None, overwrite=False): """ Create a view of a table in the destination Snowflake database. @@ -220,17 +237,14 @@ def create_view(self, table, name_table=None, overwrite=False): View: Ibis view object created in the destination Snowflake database. """ name_table = name_table or self._get_output_table_name(table) - + # Check if the destination database exists, if not, create it - catalog, database = self.SNOWFLAKE_DEST_DATABASE.split('.') + catalog, database = self.SNOWFLAKE_DEST_DATABASE.split(".") if not database in self.dest_connection.list_databases(): self.dest_connection.create_database(name=database, catalog=catalog) - + return self.dest_connection.create_view( - name=name_table, - database=database, - obj=table, - overwrite=overwrite + name=name_table, database=database, obj=table, overwrite=overwrite ) def create_table(self, table, name_table=None, overwrite=False): @@ -246,17 +260,14 @@ def create_table(self, table, name_table=None, overwrite=False): Table: Ibis table object created in the destination Snowflake database. """ name_table = name_table or self._get_output_table_name(table) - + # Check if the destination database exists, if not, create it - catalog, database = self.SNOWFLAKE_DEST_DATABASE.split('.') + catalog, database = self.SNOWFLAKE_DEST_DATABASE.split(".") if not database in self.dest_connection.list_databases(): self.dest_connection.create_database(name=database, catalog=catalog) - + return self.dest_connection.create_table( - name=name_table, - database=database, - obj=table, - overwrite=overwrite + name=name_table, database=database, obj=table, overwrite=overwrite ) def drop_table(self, name_table): @@ -270,10 +281,9 @@ def drop_table(self, name_table): None """ return self.dest_connection.drop_table( - name=name_table, - database=self.SNOWFLAKE_DEST_DATABASE - ) - + name=name_table, database=self.SNOWFLAKE_DEST_DATABASE + ) + def drop_view(self, name_table): """ Drop a view from the destination Snowflake database. @@ -285,9 +295,9 @@ def drop_view(self, name_table): None """ return self.dest_connection.drop_view( - name=name_table, - database=self.SNOWFLAKE_DEST_DATABASE - ) + name=name_table, database=self.SNOWFLAKE_DEST_DATABASE + ) + class DuckDBConnector: """ @@ -300,6 +310,7 @@ class DuckDBConnector: connect() -> BaseBackend: Establishes and returns an Ibis backend connection to the DuckDB database. """ + def __init__(self, DUCKDB_PATH: Optional[str] = ":memory"): """ Initializes the DuckDBConnector with the specified path. @@ -318,4 +329,6 @@ def connect(self) -> BaseBackend: """ required_vars = ["DUCKDB_PATH"] _check_env_vars(*required_vars) - return ibis.connect(backend="duckdb", path=self.DUCKDB_PATH or os.getenv("DUCKDB_PATH")) + return ibis.connect( + backend="duckdb", path=self.DUCKDB_PATH or os.getenv("DUCKDB_PATH") + ) diff --git a/phenex/mappers.py b/phenex/mappers.py index baf6d2b..bc71ab7 100644 --- a/phenex/mappers.py +++ b/phenex/mappers.py @@ -34,16 +34,16 @@ def set_mapped_tables(self, con, overwrite=False) -> Dict[str, Table]: Returns: Dict[str, Table]: A dictionary where keys are domain names and values are mapped tables. """ - existing_tables = con.dest_connection.list_tables(database=con.SNOWFLAKE_DEST_DATABASE) + existing_tables = con.dest_connection.list_tables( + database=con.SNOWFLAKE_DEST_DATABASE + ) for domain, mapper in self.domains_dict.items(): if domain not in existing_tables or overwrite: t = con.get_source_table(mapper.NAME_TABLE) mapped_table = mapper(t).table con.create_view( - mapped_table, - name_table=mapper.NAME_TABLE, - overwrite=overwrite - ) + mapped_table, name_table=mapper.NAME_TABLE, overwrite=overwrite + ) def get_mapped_tables(self, con) -> Dict[str, PhenexTable]: """ @@ -69,134 +69,135 @@ def get_mapped_tables(self, con) -> Dict[str, PhenexTable]: # OMOP Column Mappers # class OMOPPersonTable(PhenexPersonTable): - NAME_TABLE = 'PERSON' - DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'DATE_OF_BIRTH': 'BIRTH_DATETIME' - } + NAME_TABLE = "PERSON" + DEFAULT_MAPPING = {"PERSON_ID": "PERSON_ID", "DATE_OF_BIRTH": "BIRTH_DATETIME"} JOIN_KEYS = { - 'OMOPConditionOccurenceTable': ['PERSON_ID'], - 'OMOPVisitDetailTable': ['PERSON_ID'] + "OMOPConditionOccurenceTable": ["PERSON_ID"], + "OMOPVisitDetailTable": ["PERSON_ID"], } + class OMOPVisitDetailTable(PhenexVisitDetailTable): - NAME_TABLE = 'VISIT_DETAIL' + NAME_TABLE = "VISIT_DETAIL" JOIN_KEYS = { - 'OMOPPersonTable': ['PERSON_ID'], - 'OMOPConditionOccurenceTable': ['PERSON_ID', 'VISIT_DETAIL_ID'] + "OMOPPersonTable": ["PERSON_ID"], + "OMOPConditionOccurenceTable": ["PERSON_ID", "VISIT_DETAIL_ID"], } + class OMOPConditionOccurenceTable(CodeTable): - NAME_TABLE = 'CONDITION_OCCURRENCE' + NAME_TABLE = "CONDITION_OCCURRENCE" JOIN_KEYS = { - 'OMOPPersonTable': ['PERSON_ID'], - 'OMOPVisitDetailTable': ['PERSON_ID', 'VISIT_DETAIL_ID'] # I changed this from EVENT_DATE + "OMOPPersonTable": ["PERSON_ID"], + "OMOPVisitDetailTable": [ + "PERSON_ID", + "VISIT_DETAIL_ID", + ], # I changed this from EVENT_DATE } DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'EVENT_DATE': "CONDITION_START_DATE", - 'CODE': "CONDITION_CONCEPT_ID", + "PERSON_ID": "PERSON_ID", + "EVENT_DATE": "CONDITION_START_DATE", + "CODE": "CONDITION_CONCEPT_ID", } + class OMOPDeathTable(PhenexTable): - NAME_TABLE = 'DEATH' - JOIN_KEYS = { - 'OMOPPersonTable': ['PERSON_ID'] - } - KNOWN_FIELDS = [ - 'PERSON_ID', - 'DATE_OF_DEATH' - ] - DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'DATE_OF_DEATH': "DEATH_DATE" - } + NAME_TABLE = "DEATH" + JOIN_KEYS = {"OMOPPersonTable": ["PERSON_ID"]} + KNOWN_FIELDS = ["PERSON_ID", "DATE_OF_DEATH"] + DEFAULT_MAPPING = {"PERSON_ID": "PERSON_ID", "DATE_OF_DEATH": "DEATH_DATE"} + class OMOPProcedureOccurrenceTable(CodeTable): - NAME_TABLE = 'PROCEDURE_OCCURRENCE' + NAME_TABLE = "PROCEDURE_OCCURRENCE" JOIN_KEYS = { - 'OMOPPersonTable': ['PERSON_ID'], - 'OMOPVisitDetailTable': ['PERSON_ID', 'VISIT_DETAIL_ID'] + "OMOPPersonTable": ["PERSON_ID"], + "OMOPVisitDetailTable": ["PERSON_ID", "VISIT_DETAIL_ID"], } DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'EVENT_DATE': "PROCEDURE_DATE", - 'CODE': "PROCEDURE_CONCEPT_ID", + "PERSON_ID": "PERSON_ID", + "EVENT_DATE": "PROCEDURE_DATE", + "CODE": "PROCEDURE_CONCEPT_ID", } + class OMOPDrugExposureTable(CodeTable): - NAME_TABLE = 'DRUG_EXPOSURE' + NAME_TABLE = "DRUG_EXPOSURE" JOIN_KEYS = { - 'OMOPPersonTable': ['PERSON_ID'], - 'OMOPVisitDetailTable': ['PERSON_ID', 'VISIT_DETAIL_ID'] + "OMOPPersonTable": ["PERSON_ID"], + "OMOPVisitDetailTable": ["PERSON_ID", "VISIT_DETAIL_ID"], } DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'EVENT_DATE': "DRUG_EXPOSURE_START_DATE", - 'CODE': "DRUG_CONCEPT_ID", + "PERSON_ID": "PERSON_ID", + "EVENT_DATE": "DRUG_EXPOSURE_START_DATE", + "CODE": "DRUG_CONCEPT_ID", } + class OMOPConditionOccurrenceSourceTable(CodeTable): - NAME_TABLE = 'CONDITION_OCCURRENCE' + NAME_TABLE = "CONDITION_OCCURRENCE" JOIN_KEYS = { - 'OMOPPersonTable': ['PERSON_ID'], - 'OMOPVisitDetailTable': ['PERSON_ID', 'VISIT_DETAIL_ID'] + "OMOPPersonTable": ["PERSON_ID"], + "OMOPVisitDetailTable": ["PERSON_ID", "VISIT_DETAIL_ID"], } DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'EVENT_DATE': "CONDITION_START_DATE", - 'CODE': "CONDITION_SOURCE_VALUE", + "PERSON_ID": "PERSON_ID", + "EVENT_DATE": "CONDITION_START_DATE", + "CODE": "CONDITION_SOURCE_VALUE", } + class OMOPProcedureOccurrenceSourceTable(CodeTable): - NAME_TABLE = 'PROCEDURE_OCCURRENCE' + NAME_TABLE = "PROCEDURE_OCCURRENCE" JOIN_KEYS = { - 'OMOPPersonTable': ['PERSON_ID'], - 'OMOPVisitDetailTable': ['PERSON_ID', 'VISIT_DETAIL_ID'] + "OMOPPersonTable": ["PERSON_ID"], + "OMOPVisitDetailTable": ["PERSON_ID", "VISIT_DETAIL_ID"], } DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'EVENT_DATE': "PROCEDURE_DATE", - 'CODE': "PROCEDURE_SOURCE_VALUE", + "PERSON_ID": "PERSON_ID", + "EVENT_DATE": "PROCEDURE_DATE", + "CODE": "PROCEDURE_SOURCE_VALUE", } + class OMOPDrugExposureSourceTable(CodeTable): - NAME_TABLE = 'DRUG_EXPOSURE' + NAME_TABLE = "DRUG_EXPOSURE" JOIN_KEYS = { - 'OMOPPersonTable': ['PERSON_ID'], - 'OMOPVisitDetailTable': ['PERSON_ID', 'VISIT_DETAIL_ID'] + "OMOPPersonTable": ["PERSON_ID"], + "OMOPVisitDetailTable": ["PERSON_ID", "VISIT_DETAIL_ID"], } DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'EVENT_DATE': "DRUG_EXPOSURE_START_DATE", - 'CODE': "DRUG_SOURCE_VALUE", + "PERSON_ID": "PERSON_ID", + "EVENT_DATE": "DRUG_EXPOSURE_START_DATE", + "CODE": "DRUG_SOURCE_VALUE", } + class OMOPPersonTableSource(PhenexPersonTable): - NAME_TABLE = 'PERSON' + NAME_TABLE = "PERSON" JOIN_KEYS = { - 'OMOPConditionOccurenceTable': ['PERSON_ID'], - 'OMOPVisitDetailTable': ['PERSON_ID'] + "OMOPConditionOccurenceTable": ["PERSON_ID"], + "OMOPVisitDetailTable": ["PERSON_ID"], } DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'DATE_OF_BIRTH': "BIRTH_DATETIME", - 'YEAR_OF_BIRTH': "YEAR_OF_BIRTH", - 'SEX': "GENDER_SOURCE_VALUE", - 'ETHNICITY': "ETHNICITY_SOURCE_VALUE", + "PERSON_ID": "PERSON_ID", + "DATE_OF_BIRTH": "BIRTH_DATETIME", + "YEAR_OF_BIRTH": "YEAR_OF_BIRTH", + "SEX": "GENDER_SOURCE_VALUE", + "ETHNICITY": "ETHNICITY_SOURCE_VALUE", } + class OMOPObservationPeriodTable(PhenexObservationPeriodTable): - NAME_TABLE = 'OBSERVATION_PERIOD' - JOIN_KEYS = { - 'OMOPPersonTable': ['PERSON_ID'] - } + NAME_TABLE = "OBSERVATION_PERIOD" + JOIN_KEYS = {"OMOPPersonTable": ["PERSON_ID"]} DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'OBSERVATION_PERIOD_START_DATE': "OBSERVATION_PERIOD_START_DATE", - 'OBSERVATION_PERIOD_END_DATE': 'OBSERVATION_PERIOD_END_DATE' + "PERSON_ID": "PERSON_ID", + "OBSERVATION_PERIOD_START_DATE": "OBSERVATION_PERIOD_START_DATE", + "OBSERVATION_PERIOD_END_DATE": "OBSERVATION_PERIOD_END_DATE", } + # # Domains # @@ -213,4 +214,4 @@ class OMOPObservationPeriodTable(PhenexObservationPeriodTable): "PERSON_SOURCE": OMOPPersonTableSource, "OBSERVATION_PERIOD": OMOPObservationPeriodTable, } -OMOPDomains = DomainsDictionary(OMOPs) \ No newline at end of file +OMOPDomains = DomainsDictionary(OMOPs) diff --git a/phenex/phenotypes/codelist_phenotype.py b/phenex/phenotypes/codelist_phenotype.py index e63f640..c2f8945 100644 --- a/phenex/phenotypes/codelist_phenotype.py +++ b/phenex/phenotypes/codelist_phenotype.py @@ -69,7 +69,7 @@ def __init__( RelativeTimeRangeFilter, List[RelativeTimeRangeFilter] ] = None, return_date="first", - categorical_filter: 'CategoricalFilter' = None + categorical_filter: "CategoricalFilter" = None, ): super(CodelistPhenotype, self).__init__() diff --git a/phenex/phenotypes/functions.py b/phenex/phenotypes/functions.py index b209e6e..6bc90ea 100644 --- a/phenex/phenotypes/functions.py +++ b/phenex/phenotypes/functions.py @@ -15,11 +15,11 @@ def hstack(phenotypes: List["Phenotype"], join_table: Table = None) -> Table: if isinstance(join_table, PhenexTable): join_table = join_table.table idx_phenotype_to_begin = 0 - join_type = "left" # if join table is defined, we want to left join + join_type = "left" # if join table is defined, we want to left join if join_table is None: idx_phenotype_to_begin = 1 join_table = phenotypes[0].namespaced_table - join_type = "outer" # if join table is NOT defined, we want an outer join + join_type = "outer" # if join table is NOT defined, we want an outer join for pt in phenotypes[idx_phenotype_to_begin:]: join_table = join_table.join(pt.namespaced_table, "PERSON_ID", how=join_type) join_table = join_table.mutate( diff --git a/phenex/tables.py b/phenex/tables.py index f9976da..70f2ba8 100644 --- a/phenex/tables.py +++ b/phenex/tables.py @@ -4,6 +4,7 @@ import ibis import copy + class PhenexTable: """ Phenex provides certain table types on which it knows how to operate. For instance, Phenex implements a CodeTable, which is an event table containing codes. Phenex has abstracted operations for each table type. For instance, given a CodeTable, Phenex knows how to filter this table based on the presence of codes within that table. Phenex doesn't care if the code table is actually a diagnosis code table or a procedure code table or a medication code table. @@ -16,20 +17,23 @@ class PhenexTable: Note that for each table type, there are REQUIRED_FIELDS, i.e., fields that MUST be defined for Phenex to work with such a table and KNOWN_FIELDS, i.e., fields that Phenex internally understands what to do with (there is a Phenotype that knows how to work with that field). For instance, in a PhenexPersonTable, one MUST define PERSON_ID, but DATE_OF_BIRTH is an optional field that PhenEx can process if given and transform into AGE. These are fixed for each table type and should not be overridden. """ - NAME_TABLE = 'PHENEX_TABLE' # name of table in the database - JOIN_KEYS = {} # dict: class name -> List[phenex column names] - KNOWN_FIELDS = [] # List[phenex column names] - DEFAULT_MAPPING = {} # dict: input column name -> phenex column name - PATHS = {} # dict: table class name -> List[other table class names] + + NAME_TABLE = "PHENEX_TABLE" # name of table in the database + JOIN_KEYS = {} # dict: class name -> List[phenex column names] + KNOWN_FIELDS = [] # List[phenex column names] + DEFAULT_MAPPING = {} # dict: input column name -> phenex column name + PATHS = {} # dict: table class name -> List[other table class names] REQUIRED_FIELDS = list(DEFAULT_MAPPING.keys()) def __init__(self, table, name=None, column_mapping={}): - ''' + """ Instantiate a PhenexTable, possibly overriding NAME_TABLE and COLUMN_MAPPING. - ''' + """ if not isinstance(table, Table): - raise TypeError(f"Cannot instantiatiate {self.__class__.__name__} from {type(table)}. Must be ibis Table.") + raise TypeError( + f"Cannot instantiatiate {self.__class__.__name__} from {type(table)}. Must be ibis Table." + ) self.NAME_TABLE = name or self.NAME_TABLE @@ -45,13 +49,15 @@ def __init__(self, table, name=None, column_mapping={}): self._add_phenotype_table_relationship() def _add_phenotype_table_relationship(self): - self.JOIN_KEYS['PhenotypeTable'] = ['PERSON_ID'] + self.JOIN_KEYS["PhenotypeTable"] = ["PERSON_ID"] def _get_column_mapping(self, column_mapping=None): column_mapping = column_mapping or {} for key in column_mapping.keys(): if key not in self.KNOWN_FIELDS: - raise ValueError(f"Unknown mapped field {key} --> {column_mapping[key]} for f{type(self)}.") + raise ValueError( + f"Unknown mapped field {key} --> {column_mapping[key]} for f{type(self)}." + ) default_mapping = copy.deepcopy(self.DEFAULT_MAPPING) default_mapping.update(column_mapping) return default_mapping @@ -67,7 +73,7 @@ def __getitem__(self, key): def table(self): return self._table - def join(self, other: 'PhenexTable', *args, domains = None, **kwargs): + def join(self, other: "PhenexTable", *args, domains=None, **kwargs): """ The join method performs a join of PhenexTables, using autojoin functionality if Phenex is able to find the table types specified in PATHS. """ @@ -83,21 +89,30 @@ def join(self, other: 'PhenexTable', *args, domains = None, **kwargs): # Do an autojoin by finding a path from the left to the right table and sequentially joining as necessary joined_table = current_table = self for next_table_class_name in self._find_path(other): - next_table = [v for k,v in domains.items() if v.__class__.__name__ == next_table_class_name] + next_table = [ + v + for k, v in domains.items() + if v.__class__.__name__ == next_table_class_name + ] if len(next_table) != 1: - raise ValueError(f"Unable to find unqiue {next_table_class_name} required to join {other.__class__.__name__}") + raise ValueError( + f"Unable to find unqiue {next_table_class_name} required to join {other.__class__.__name__}" + ) next_table = next_table[0] - print(f"Joining : {current_table.__class__.__name__} to {next_table.__class__.__name__}") + print( + f"Joining : {current_table.__class__.__name__} to {next_table.__class__.__name__}" + ) # join keys are defined by the left table; in theory should enforce symmetry join_keys = current_table.JOIN_KEYS[next_table_class_name] columns = list(set(joined_table.columns + next_table.columns)) - joined_table = joined_table.join(next_table, join_keys, **kwargs).select(columns) + joined_table = joined_table.join(next_table, join_keys, **kwargs).select( + columns + ) current_table = next_table return joined_table - def _find_path(self, other): start_name = self.__class__.__name__ @@ -110,15 +125,17 @@ def _find_path(self, other): try: return self.PATHS[end_name] + [end_name] except KeyError: - raise ValueError(f'Cannot autojoin {start_name} --> {end_name}. Please specify join path in PATHS.') + raise ValueError( + f"Cannot autojoin {start_name} --> {end_name}. Please specify join path in PATHS." + ) def select(self, *args, **kwargs): return type(self)(self.table.select(*args, **kwargs), name=self.NAME_TABLE) def filter(self, expr): - ''' + """ Filter the table by an Ibis Expression or using a PhenExFilter. - ''' + """ input_columns = self.columns if isinstance(expr, ibis.expr.types.Expr) or isinstance(expr, list): filtered_table = self.table.filter(expr) @@ -128,127 +145,108 @@ def filter(self, expr): return type(self)( filtered_table.select(input_columns), name=self.NAME_TABLE, - column_mapping=self.column_mapping) + column_mapping=self.column_mapping, + ) class PhenexPersonTable(PhenexTable): - NAME_TABLE = 'PERSON' - JOIN_KEYS = { - 'CodeTable': ['PERSON_ID'], - 'PhenexVisitDetailTable': ['PERSON_ID'] - } + NAME_TABLE = "PERSON" + JOIN_KEYS = {"CodeTable": ["PERSON_ID"], "PhenexVisitDetailTable": ["PERSON_ID"]} KNOWN_FIELDS = [ - 'PERSON_ID', - 'DATE_OF_BIRTH', - 'YEAR_OF_BIRTH', - 'DATE_OF_DEATH', - 'SEX', - 'ETHNICITY' + "PERSON_ID", + "DATE_OF_BIRTH", + "YEAR_OF_BIRTH", + "DATE_OF_DEATH", + "SEX", + "ETHNICITY", ] - DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID" - } + DEFAULT_MAPPING = {"PERSON_ID": "PERSON_ID"} + class EventTable(PhenexTable): - NAME_TABLE = 'EVENT' - KNOWN_FIELDS = [ - 'PERSON_ID', - 'EVENT_DATE' - ] + NAME_TABLE = "EVENT" + KNOWN_FIELDS = ["PERSON_ID", "EVENT_DATE"] DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'EVENT_DATE': "EVENT_DATE", + "PERSON_ID": "PERSON_ID", + "EVENT_DATE": "EVENT_DATE", } + class CodeTable(PhenexTable): - NAME_TABLE = 'CODE' + NAME_TABLE = "CODE" RELATIONSHIPS = { - 'PhenexPersonTable': ['PERSON_ID'], - 'PhenexVisitDetailTable': ['PERSON_ID', 'VISIT_DETAIL_ID'], + "PhenexPersonTable": ["PERSON_ID"], + "PhenexVisitDetailTable": ["PERSON_ID", "VISIT_DETAIL_ID"], } - KNOWN_FIELDS = [ - 'PERSON_ID', - 'EVENT_DATE', - 'CODE', - 'CODE_TYPE', - 'VISIT_DETAIL_ID' - ] + KNOWN_FIELDS = ["PERSON_ID", "EVENT_DATE", "CODE", "CODE_TYPE", "VISIT_DETAIL_ID"] DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'EVENT_DATE': "EVENT_DATE", - 'CODE': "CODE", + "PERSON_ID": "PERSON_ID", + "EVENT_DATE": "EVENT_DATE", + "CODE": "CODE", } + class PhenexVisitDetailTable(PhenexTable): - NAME_TABLE = 'VISIT_DETAIL' + NAME_TABLE = "VISIT_DETAIL" RELATIONSHIPS = { - 'PhenexPersonTable': ['PERSON_ID'], - 'CodeTable': ['PERSON_ID', 'VISIT_DETAIL_ID'], + "PhenexPersonTable": ["PERSON_ID"], + "CodeTable": ["PERSON_ID", "VISIT_DETAIL_ID"], } KNOWN_FIELDS = [ - 'PERSON_ID', - 'VISIT_DETAIL_ID', - 'VISIT_DETAIL_SOURCE_VALUE', + "PERSON_ID", + "VISIT_DETAIL_ID", + "VISIT_DETAIL_SOURCE_VALUE", ] DEFAULT_MAPPING = { - 'PERSON_ID':'PERSON_ID', - 'VISIT_DETAIL_ID': 'VISIT_DETAIL_ID', - 'VISIT_DETAIL_SOURCE_VALUE': 'VISIT_DETAIL_SOURCE_VALUE' + "PERSON_ID": "PERSON_ID", + "VISIT_DETAIL_ID": "VISIT_DETAIL_ID", + "VISIT_DETAIL_SOURCE_VALUE": "VISIT_DETAIL_SOURCE_VALUE", } class PhenexIndexTable(PhenexTable): - NAME_TABLE = 'INDEX' - JOIN_KEYS = { - 'CodeTable': ['PERSON_ID'], - 'PhenexVisitDetailTable': ['PERSON_ID'] - } + NAME_TABLE = "INDEX" + JOIN_KEYS = {"CodeTable": ["PERSON_ID"], "PhenexVisitDetailTable": ["PERSON_ID"]} KNOWN_FIELDS = [ - 'PERSON_ID', - 'INDEX_DATE', + "PERSON_ID", + "INDEX_DATE", ] - DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - "INDEX_DATE": "INDEX_DATE" - } + DEFAULT_MAPPING = {"PERSON_ID": "PERSON_ID", "INDEX_DATE": "INDEX_DATE"} + class PhenexObservationPeriodTable(PhenexTable): - NAME_TABLE = 'OBSERVATION_PERIOD' + NAME_TABLE = "OBSERVATION_PERIOD" KNOWN_FIELDS = [ - 'PERSON_ID', - 'OBSERVATION_PERIOD_START_DATE', - 'OBSERVATION_PERIOD_END_DATE' + "PERSON_ID", + "OBSERVATION_PERIOD_START_DATE", + "OBSERVATION_PERIOD_END_DATE", ] DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", + "PERSON_ID": "PERSON_ID", "OBSERVATION_PERIOD_START_DATE": "OBSERVATION_PERIOD_START_DATE", - "OBSERVATION_PERIOD_END_DATE": 'OBSERVATION_PERIOD_END_DATE' + "OBSERVATION_PERIOD_END_DATE": "OBSERVATION_PERIOD_END_DATE", } + class MeasurementTable(Table): # These datatpyes are just used for type hinting pass class PhenotypeTable(PhenexTable): - NAME_TABLE = 'PHENOTYPE' - KNOWN_FIELDS = [ - 'PERSON_ID', - 'BOOLEAN', - 'EVENT_DATE', - 'VALUE' - ] + NAME_TABLE = "PHENOTYPE" + KNOWN_FIELDS = ["PERSON_ID", "BOOLEAN", "EVENT_DATE", "VALUE"] DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", + "PERSON_ID": "PERSON_ID", "BOOLEAN": "BOOLEAN", "EVENT_DATE": "EVENT_DATE", - "VALUE": "VALUE" + "VALUE": "VALUE", } diff --git a/phenex/test/phenotype_test_generator.py b/phenex/test/phenotype_test_generator.py index ecc5295..13ca03b 100644 --- a/phenex/test/phenotype_test_generator.py +++ b/phenex/test/phenotype_test_generator.py @@ -11,6 +11,7 @@ import logging from phenex.tables import * + class PhenotypeTestGenerator: """ This class is a base class for all TestGenerators. @@ -84,13 +85,16 @@ def _create_input_data(self): table = self.con.create_table( input_info["name"], input_info["df"], schema=schema ) - if 'type' in input_info.keys(): - table_type = input_info['type'] + if "type" in input_info.keys(): + table_type = input_info["type"] else: table_type = PhenexTable - if input_info['name'].lower() in ['condition_occurrence', 'measurement']: + if input_info["name"].lower() in [ + "condition_occurrence", + "measurement", + ]: table_type = CodeTable - elif input_info['name'].lower() in ['person']: + elif input_info["name"].lower() in ["person"]: table_type = PhenexPersonTable self.domains[input_info["name"]] = table_type(table) diff --git a/phenex/test/phenotypes/test_codelist_phenotype.py b/phenex/test/phenotypes/test_codelist_phenotype.py index a4dbb2b..79c2ccc 100644 --- a/phenex/test/phenotypes/test_codelist_phenotype.py +++ b/phenex/test/phenotypes/test_codelist_phenotype.py @@ -528,42 +528,45 @@ def define_phenotype_tests(self): from phenex.tables import CodeTable, PhenexTable + class DummyConditionOccurenceTable(CodeTable): - NAME_TABLE = 'DIAGNOSIS' + NAME_TABLE = "DIAGNOSIS" JOIN_KEYS = { - 'DummyPersonTable': ['PERSON_ID'], - 'DummyEncounterTable': ['PERSON_ID', 'ENCID'] # I changed this from EVENT_DATE - } - PATHS = { - 'DummyVisitDetailTable': ['DummyEncounterTable'] + "DummyPersonTable": ["PERSON_ID"], + "DummyEncounterTable": ["PERSON_ID", "ENCID"], # I changed this from EVENT_DATE } + PATHS = {"DummyVisitDetailTable": ["DummyEncounterTable"]} DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID", - 'EVENT_DATE': "DATE", - 'CODE': "CODE", - "CODE_TYPE":"CODE_TYPE" + "PERSON_ID": "PERSON_ID", + "EVENT_DATE": "DATE", + "CODE": "CODE", + "CODE_TYPE": "CODE_TYPE", } + class DummyEncounterTable(PhenexTable): - NAME_TABLE = 'ENCOUNTER' + NAME_TABLE = "ENCOUNTER" JOIN_KEYS = { - 'DummyPersonTable': ['PERSON_ID'], - 'DummyConditionOccurenceTable': ['PERSON_ID', 'ENCID'], # I changed this from EVENT_DATE - 'DummyVisitDetailTable': ['PERSON_ID', 'VISITID'] # I changed this from EVENT_DATE - } - DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID" + "DummyPersonTable": ["PERSON_ID"], + "DummyConditionOccurenceTable": [ + "PERSON_ID", + "ENCID", + ], # I changed this from EVENT_DATE + "DummyVisitDetailTable": [ + "PERSON_ID", + "VISITID", + ], # I changed this from EVENT_DATE } + DEFAULT_MAPPING = {"PERSON_ID": "PERSON_ID"} + class DummyVisitDetailTable(PhenexTable): - NAME_TABLE = 'VISIT' + NAME_TABLE = "VISIT" JOIN_KEYS = { - 'DummyPersonTable': ['PERSON_ID'], - 'DummyEncounterTable': ['PERSON_ID', 'VISITID'] - } - DEFAULT_MAPPING = { - 'PERSON_ID': "PERSON_ID" + "DummyPersonTable": ["PERSON_ID"], + "DummyEncounterTable": ["PERSON_ID", "VISITID"], } + DEFAULT_MAPPING = {"PERSON_ID": "PERSON_ID"} class CodelistPhenotypeCategoricalFilterTestGenerator(PhenotypeTestGenerator): @@ -579,21 +582,24 @@ def define_input_tables(self): df["DATE"] = datetime.datetime.strptime("10-10-2021", "%m-%d-%Y") df2 = pd.DataFrame() - df2['PERSON_ID'] = [f"P{i}" for i in range(N)] + df2["PERSON_ID"] = [f"P{i}" for i in range(N)] df2["ENCID"] = [f"E{i}" for i in range(N)] - df2['VISITID'] = [f"V{i}" for i in range(N)] - df2['flag1'] = ['a']*2 + ['b']*2 + ['c']*(N-2-2) + df2["VISITID"] = [f"V{i}" for i in range(N)] + df2["flag1"] = ["a"] * 2 + ["b"] * 2 + ["c"] * (N - 2 - 2) df3 = pd.DataFrame() - df3['PERSON_ID'] = [f"P{i}" for i in range(N)] - df3['VISITID'] = [f"V{i}" for i in range(N)] - df3['flag2'] = ['d']*5 + ['e']*3 + ['f']*(N-5-3) - + df3["PERSON_ID"] = [f"P{i}" for i in range(N)] + df3["VISITID"] = [f"V{i}" for i in range(N)] + df3["flag2"] = ["d"] * 5 + ["e"] * 3 + ["f"] * (N - 5 - 3) return [ - {"name": "condition_occurrence", "df": df, "type":DummyConditionOccurenceTable}, - {"name": "encounter", "df":df2, "type":DummyEncounterTable}, - {"name": "visit", "df":df3, "type":DummyVisitDetailTable} + { + "name": "condition_occurrence", + "df": df, + "type": DummyConditionOccurenceTable, + }, + {"name": "encounter", "df": df2, "type": DummyEncounterTable}, + {"name": "visit", "df": df3, "type": DummyVisitDetailTable}, ] def define_phenotype_tests(self): @@ -608,19 +614,19 @@ def define_phenotype_tests(self): codelist=codelist_factory.get_codelist("c1"), domain="condition_occurrence", categorical_filter=CategoricalFilter( - allowed_values=["a"], column_name="flag1", domain = 'encounter' + allowed_values=["a"], column_name="flag1", domain="encounter" ), ), } c1b = { "name": "single_flag_direct_join_b", - "persons": [f"P{i}" for i in range(2,4)], + "persons": [f"P{i}" for i in range(2, 4)], "phenotype": CodelistPhenotype( codelist=codelist_factory.get_codelist("c1"), domain="condition_occurrence", categorical_filter=CategoricalFilter( - allowed_values=["b"], column_name="flag1", domain = 'encounter' + allowed_values=["b"], column_name="flag1", domain="encounter" ), ), } @@ -632,19 +638,19 @@ def define_phenotype_tests(self): codelist=codelist_factory.get_codelist("c1"), domain="condition_occurrence", categorical_filter=CategoricalFilter( - allowed_values=["d"], column_name="flag2", domain = 'visit' + allowed_values=["d"], column_name="flag2", domain="visit" ), ), } c2b = { "name": "single_flag_intermediary_join_b", - "persons": [f"P{i}" for i in range(5,8)], + "persons": [f"P{i}" for i in range(5, 8)], "phenotype": CodelistPhenotype( codelist=codelist_factory.get_codelist("c1"), domain="condition_occurrence", categorical_filter=CategoricalFilter( - allowed_values=["e"], column_name="flag2", domain = 'visit' + allowed_values=["e"], column_name="flag2", domain="visit" ), ), } @@ -655,6 +661,7 @@ def define_phenotype_tests(self): return test_infos + # class CodelistPhenotypeCategoricalFilterLogicalCombinationsTestGenerator(PhenotypeTestGenerator): # name_space = "clpt_categorical_filter_logic" @@ -769,6 +776,7 @@ def test_categorical_filter_phenotype(): tg = CodelistPhenotypeCategoricalFilterTestGenerator() tg.run_tests() + if __name__ == "__main__": test_categorical_filter_phenotype() test_relative_time_range_filter()