diff --git a/src/gentropy/common/schemas.py b/src/gentropy/common/schemas.py
index 1dcd75a22..624e3e0e1 100644
--- a/src/gentropy/common/schemas.py
+++ b/src/gentropy/common/schemas.py
@@ -1,66 +1,212 @@
 """Methods for handling schemas."""
+
 from __future__ import annotations
 
 import importlib.resources as pkg_resources
 import json
-from collections import namedtuple
-from typing import Any
+from collections import defaultdict
 
-import pyspark.sql.types as t
+from pyspark.sql.types import ArrayType, StructType
 
 from gentropy.assets import schemas
 
 
-def parse_spark_schema(schema_json: str) -> t.StructType:
+class SchemaValidationError(Exception):
+    """This exception is raised when a schema validation fails."""
+
+    def __init__(
+        self: SchemaValidationError, message: str, errors: defaultdict[str, list[str]]
+    ) -> None:
+        """Initialize the SchemaValidationError.
+
+        Args:
+            message (str): The message to be displayed.
+            errors (defaultdict[str, list[str]]): The collection of observed discrepancies.
+        """
+        super().__init__(message)
+        self.message = message  # Explicitly set the message attribute
+        self.errors = errors
+
+    def __str__(self: SchemaValidationError) -> str:
+        """Return a string representation of the exception.
+
+        Returns:
+            str: The string representation of the exception.
+        """
+        stringified_errors = "\n ".join(
+            [f'{k}: {",".join(v)}' for k, v in self.errors.items()]
+        )
+        return f"{self.message}\nErrors:\n {stringified_errors}"
+
+
+def parse_spark_schema(schema_json: str) -> StructType:
     """Parse Spark schema from JSON.
 
     Args:
         schema_json (str): JSON filename containing spark schema in the schemas package
 
     Returns:
-        t.StructType: Spark schema
+        StructType: Spark schema
     """
     core_schema = json.loads(
         pkg_resources.read_text(schemas, schema_json, encoding="utf-8")
     )
-    return t.StructType.fromJson(core_schema)
+    return StructType.fromJson(core_schema)
 
 
-def flatten_schema(schema: t.StructType, prefix: str = "") -> list[Any]:
-    """It takes a Spark schema and returns a list of all fields in the schema once flattened.
+def compare_array_schemas(
+    observed_schema: ArrayType,
+    expected_schema: ArrayType,
+    parent_field_name: str | None = None,
+    schema_issues: defaultdict[str, list[str]] | None = None,
+) -> defaultdict[str, list[str]]:
+    """Compare two array schemas.
+
+    The comparison is done recursively, so nested structs are also compared.
 
     Args:
-        schema (t.StructType): The schema of the dataframe
-        prefix (str): The prefix to prepend to the field names. Defaults to "".
+        observed_schema (ArrayType): The observed schema.
+        expected_schema (ArrayType): The expected schema.
+        parent_field_name (str | None): The parent field name. Defaults to None.
+        schema_issues (defaultdict[str, list[str]] | None): The schema issues. Defaults to None.
 
     Returns:
-        list[Any]: A list of all the columns in the dataframe.
-
-    Examples:
-        >>> from pyspark.sql.types import ArrayType, StringType, StructField, StructType
-        >>> schema = StructType(
-        ...     [
-        ...         StructField("studyLocusId", StringType(), False),
-        ...         StructField("locus", ArrayType(StructType([StructField("variantId", StringType(), False)])), False)
-        ...     ]
-        ... )
-        >>> df = spark.createDataFrame([("A", [{"variantId": "varA"}]), ("B", [{"variantId": "varB"}])], schema)
-        >>> flatten_schema(df.schema)
-        [Field(name='studyLocusId', dataType=StringType()), Field(name='locus', dataType=ArrayType(StructType([]), True)), Field(name='locus.variantId', dataType=StringType())]
+        defaultdict[str, list[str]]: The schema issues.
""" - Field = namedtuple("Field", ["name", "dataType"]) - fields = [] - for field in schema.fields: - name = f"{prefix}.{field.name}" if prefix else field.name - dtype = field.dataType - if isinstance(dtype, t.StructType): - fields.append(Field(name, t.ArrayType(t.StructType()))) - fields += flatten_schema(dtype, prefix=name) - elif isinstance(dtype, t.ArrayType) and isinstance( - dtype.elementType, t.StructType - ): - fields.append(Field(name, t.ArrayType(t.StructType()))) - fields += flatten_schema(dtype.elementType, prefix=name) - else: - fields.append(Field(name, dtype)) - return fields + # Create default values if not provided: + if schema_issues is None: + schema_issues = defaultdict(list) + + if parent_field_name is None: + parent_field_name = "" + + observed_type = observed_schema.elementType.typeName() + expected_type = expected_schema.elementType.typeName() + + # If element types are not matching, no further tests are needed: + if observed_type != expected_type: + schema_issues["columns_with_non_matching_type"].append( + f'For column "{parent_field_name}[]" found {observed_type} instead of {expected_type}' + ) + + # If element type is a struct, resolve nesting: + elif (observed_type == "struct") and (expected_type == "struct"): + schema_issues = compare_struct_schemas( + observed_schema.elementType, + expected_schema.elementType, + f"{parent_field_name}[].", + schema_issues, + ) + + # If element type is an array, resolve nesting: + elif (observed_type == "array") and (expected_type == "array"): + schema_issues = compare_array_schemas( + observed_schema.elementType, + expected_schema.elementType, + parent_field_name, + schema_issues, + ) + + return schema_issues + + +def compare_struct_schemas( + observed_schema: StructType, + expected_schema: StructType, + parent_field_name: str | None = None, + schema_issues: defaultdict[str, list[str]] | None = None, +) -> defaultdict[str, list[str]]: + """Compare two struct schemas. + + The comparison is done recursively, so nested structs are also compared. + + Checking logic: + 1. Checking for duplicated columns in the observed schema. + 2. Checking for missing mandatory columns in the observed schema. + 3. Now we know that all mandatory columns are present, we can iterate over the observed schema and compare the types. + 4. Flagging unexpected columns in the observed schema. + 5. Flagging columns with non-matching types. + 6. If a column is a struct -> call compare_struct_schemas + 7. If a column is an array -> call compare_array_schemas + 8. Return dictionary with issues. + + Args: + observed_schema (StructType): The observed schema. + expected_schema (StructType): The expected schema. + parent_field_name (str | None): The parent field name. Defaults to None. + schema_issues (defaultdict[str, list[str]] | None): The schema issues. Defaults to None. + + Returns: + defaultdict[str, list[str]]: The schema issues. 
+ """ + # Create default values if not provided: + if schema_issues is None: + schema_issues = defaultdict(list) + + if parent_field_name is None: + parent_field_name = "" + + # Flagging duplicated columns if present: + if duplicated_columns := list( + { + f"{parent_field_name}{field.name}" + for field in observed_schema + if list(observed_schema).count(field) > 1 + } + ): + schema_issues["duplicated_columns"] += duplicated_columns + + # Testing mandatory fields: + required_fields = [x.name for x in expected_schema if not x.nullable] + if missing_required_fields := [ + f"{parent_field_name}{req}" + for req in required_fields + if not any(field.name == req for field in observed_schema) + ]: + schema_issues["missing_mandatory_columns"] += missing_required_fields + + # Converting schema to dictionaries for easier comparison: + observed_schema_dict = {field.name: field for field in observed_schema} + expected_schema_dict = {field.name: field for field in expected_schema} + + # Testing optional fields and types: + for field_name, field in observed_schema_dict.items(): + # Testing observed field name, if name is not matched, no further tests are needed: + if field_name not in expected_schema_dict: + schema_issues["unexpected_columns"].append( + f"{parent_field_name}{field_name}" + ) + continue + + # When we made sure the field is in both schemas, extracting field type information: + observed_type = field.dataType + observed_type_name = field.dataType.typeName() + + expected_type = expected_schema_dict[field_name].dataType + expected_type_name = expected_schema_dict[field_name].dataType.typeName() + + # Flagging non-matching types if types don't match, jumping to next field: + if observed_type_name != expected_type_name: + schema_issues["columns_with_non_matching_type"].append( + f'For column "{parent_field_name}{field_name}" found {observed_type_name} instead of {expected_type_name}' + ) + continue + + # If column is a struct, resolve nesting: + if observed_type_name == "struct": + schema_issues = compare_struct_schemas( + observed_type, + expected_type, + f"{parent_field_name}{field_name}.", + schema_issues, + ) + # If column is an array, resolve nesting: + elif observed_type_name == "array": + schema_issues = compare_array_schemas( + observed_type, + expected_type, + f"{parent_field_name}{field_name}[]", + schema_issues, + ) + + return schema_issues diff --git a/src/gentropy/common/spark_helpers.py b/src/gentropy/common/spark_helpers.py index 680975ef6..4d24212a2 100644 --- a/src/gentropy/common/spark_helpers.py +++ b/src/gentropy/common/spark_helpers.py @@ -270,13 +270,13 @@ def neglog_pvalue_to_mantissa_and_exponent(p_value: Column) -> tuple[Column, Col +--------+--------------+--------------+ |negLogPv|pValueMantissa|pValueExponent| +--------+--------------+--------------+ - | 4.56| 3.6307805| -5| - | 2109.23| 1.6982436| -2110| + | 4.56| 2.7542286| -5| + | 2109.23| 5.8884363| -2110| +--------+--------------+--------------+ """ exponent: Column = f.ceil(p_value) - mantissa: Column = f.pow(f.lit(10), (p_value - exponent + f.lit(1))) + mantissa: Column = f.pow(f.lit(10), (exponent - p_value)) return ( mantissa.cast(t.FloatType()).alias("pValueMantissa"), @@ -614,14 +614,21 @@ def rename_all_columns(df: DataFrame, prefix: str) -> DataFrame: ) -def safe_array_union(a: Column, b: Column) -> Column: +def safe_array_union( + a: Column, b: Column, fields_order: list[str] | None = None +) -> Column: """Merge the content of two optional columns. 
diff --git a/src/gentropy/common/spark_helpers.py b/src/gentropy/common/spark_helpers.py
index 680975ef6..4d24212a2 100644
--- a/src/gentropy/common/spark_helpers.py
+++ b/src/gentropy/common/spark_helpers.py
@@ -270,13 +270,13 @@ def neglog_pvalue_to_mantissa_and_exponent(p_value: Column) -> tuple[Column, Col
         +--------+--------------+--------------+
         |negLogPv|pValueMantissa|pValueExponent|
         +--------+--------------+--------------+
-        |    4.56|     3.6307805|            -5|
-        | 2109.23|     1.6982436|         -2110|
+        |    4.56|     2.7542286|            -5|
+        | 2109.23|     5.8884363|         -2110|
         +--------+--------------+--------------+
     """
     exponent: Column = f.ceil(p_value)
-    mantissa: Column = f.pow(f.lit(10), (p_value - exponent + f.lit(1)))
+    mantissa: Column = f.pow(f.lit(10), (exponent - p_value))
 
     return (
         mantissa.cast(t.FloatType()).alias("pValueMantissa"),
@@ -614,14 +614,21 @@ def rename_all_columns(df: DataFrame, prefix: str) -> DataFrame:
     )
 
 
-def safe_array_union(a: Column, b: Column) -> Column:
+def safe_array_union(
+    a: Column, b: Column, fields_order: list[str] | None = None
+) -> Column:
     """Merge the content of two optional columns.
 
-    The function assumes the array columns have the same schema. Otherwise, the function will fail.
+    The function assumes the array columns have the same schema.
+    If `fields_order` is passed, the function assumes it deals with arrays of structs and sorts the nested
+    struct fields by the provided `fields_order` before conducting the array union.
+    If `fields_order` is not passed and both columns are of array<struct<...>> type, the function assumes the
+    struct fields have the same order; otherwise the function will raise an AnalysisException.
 
     Args:
         a (Column): One optional array column.
         b (Column): The other optional array column.
+        fields_order (list[str] | None): The order of the fields in the struct. Defaults to None.
 
     Returns:
         Column: array column with merged content.
@@ -644,12 +651,88 @@ def safe_array_union(a: Column, b: Column) -> Column:
         |  null|
         +------+
 
+        >>> schema = "arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
+        >>> data = [([(1, "a",), (2, "c")], [("a", 1,)]),]
+        >>> df = spark.createDataFrame(data=data, schema=schema)
+        >>> df.select(safe_array_union(f.col("arr"), f.col("arr2"), fields_order=["a", "b"]).alias("merged")).show()
+        +----------------+
+        |          merged|
+        +----------------+
+        |[{a, 1}, {c, 2}]|
+        +----------------+
+
+        >>> schema = "arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
+        >>> data = [([(1, "a",), (2, "c")], [("a", 1,)]),]
+        >>> df = spark.createDataFrame(data=data, schema=schema)
+        >>> df.select(safe_array_union(f.col("arr"), f.col("arr2")).alias("merged")).show()  # doctest: +IGNORE_EXCEPTION_DETAIL
+        Traceback (most recent call last):
+        pyspark.sql.utils.AnalysisException: ...
     """
+    if fields_order:
+        # Sort the nested struct fields by the provided order:
+        a = sort_array_struct_by_columns(a, fields_order)
+        b = sort_array_struct_by_columns(b, fields_order)
     return f.when(a.isNotNull() & b.isNotNull(), f.array_union(a, b)).otherwise(
         f.coalesce(a, b)
    )
 
 
+def sort_array_struct_by_columns(column: Column, fields_order: list[str]) -> Column:
+    """Sort nested struct fields by provided fields order.
+
+    Args:
+        column (Column): Column with array of structs.
+        fields_order (list[str]): List of field names to sort by.
+
+    Returns:
+        Column: Sorted column.
+
+    Examples:
+        >>> schema = "arr: array<struct<b:int,a:string>>"
+        >>> data = [([(1, "a",), (2, "c")],)]
+        >>> fields_order = ["a", "b"]
+        >>> df = spark.createDataFrame(data=data, schema=schema)
+        >>> df.select(sort_array_struct_by_columns(f.col("arr"), fields_order).alias("sorted")).show()
+        +----------------+
+        |          sorted|
+        +----------------+
+        |[{c, 2}, {a, 1}]|
+        +----------------+
+
+    """
+    column_name = extract_column_name(column)
+    fields_order_expr = ", ".join([f"x.{field}" for field in fields_order])
+    return f.expr(
+        f"sort_array(transform({column_name}, x -> struct({fields_order_expr})), False)"
+    ).alias(column_name)
+
+
+def extract_column_name(column: Column) -> str:
+    """Extract column name from a column expression.
+
+    Args:
+        column (Column): Column expression.
+
+    Returns:
+        str: Column name.
+
+    Raises:
+        ValueError: If the column name cannot be extracted.
+
+    Examples:
+        >>> extract_column_name(f.col('col1'))
+        'col1'
+        >>> extract_column_name(f.sort_array(f.col('col1')))
+        'sort_array(col1, true)'
+    """
+    pattern = re.compile("^Column<'(?P<name>.*)'>?")
+
+    _match = pattern.search(str(column))
+    if not _match:
+        raise ValueError(f"Cannot extract column name from {column}")
+    return _match.group("name")
+
+
 def create_empty_column_if_not_exists(
     col_name: str, col_schema: t.DataType = t.NullType()
 ) -> Column:
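A quick plain-Python sanity check of the corrected mantissa/exponent decomposition in the first hunk above; `math.ceil` stands in for `f.ceil`, so this merely mirrors the Spark expression:

```python
import math


def neglog_to_mantissa_exponent(neglog_pv: float) -> tuple[float, int]:
    """Decompose p = 10**(-neglog_pv) into mantissa * 10**exponent."""
    exponent = math.ceil(neglog_pv)
    mantissa = 10 ** (exponent - neglog_pv)  # the fixed formula from the diff
    return mantissa, -exponent


print(neglog_to_mantissa_exponent(4.56))  # (~2.7542, -5): 2.7542e-5 == 10**-4.56
print(neglog_to_mantissa_exponent(2109.23))  # (~5.8884, -2110)
```

With the old formula, `10 ** (p_value - exponent + 1)`, the mantissa for 4.56 came out as ~3.63, i.e. 3.63e-5 rather than 10**-4.56 = 2.75e-5, which is exactly what the updated doctest values reflect.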
diff --git a/src/gentropy/config.py b/src/gentropy/config.py
index 3a67e7868..0a1f9438a 100644
--- a/src/gentropy/config.py
+++ b/src/gentropy/config.py
@@ -121,10 +121,16 @@ class GWASCatalogSumstatsPreprocessConfig(StepConfig):
 class EqtlCatalogueConfig(StepConfig):
     """eQTL Catalogue step configuration."""
 
+    session: Any = field(
+        default_factory=lambda: {
+            "start_hail": True,
+        }
+    )
     eqtl_catalogue_paths_imported: str = MISSING
     eqtl_catalogue_study_index_out: str = MISSING
     eqtl_catalogue_credible_sets_out: str = MISSING
     mqtl_quantification_methods_blacklist: list[str] = field(default_factory=lambda: [])
+    eqtl_lead_pvalue_threshold: float = 1e-3
     _target_: str = "gentropy.eqtl_catalogue.EqtlCatalogueStep"
 
 
@@ -168,6 +174,7 @@ class FinngenFinemappingConfig(StepConfig):
     _target_: str = (
         "gentropy.finngen_finemapping_ingestion.FinnGenFinemappingIngestionStep"
     )
+    finngen_finemapping_lead_pvalue_threshold: float = 1e-5
 
 
 @dataclass
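The two new thresholds can be read straight off the config dataclasses; a minimal sketch (assuming gentropy is importable and the dataclasses can be instantiated with their defaults):

```python
from gentropy.config import EqtlCatalogueConfig, FinngenFinemappingConfig

# Defaults introduced by this diff:
assert EqtlCatalogueConfig().eqtl_lead_pvalue_threshold == 1e-3
assert FinngenFinemappingConfig().finngen_finemapping_lead_pvalue_threshold == 1e-5
```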
diff --git a/src/gentropy/dataset/dataset.py b/src/gentropy/dataset/dataset.py
index c822b592a..d033e129d 100644
--- a/src/gentropy/dataset/dataset.py
+++ b/src/gentropy/dataset/dataset.py
@@ -13,7 +13,7 @@ from pyspark.sql.window import Window
 from typing_extensions import Self
 
-from gentropy.common.schemas import flatten_schema
+from gentropy.common.schemas import SchemaValidationError, compare_struct_schemas
 
 if TYPE_CHECKING:
     from enum import Enum
@@ -142,57 +142,15 @@ def validate_schema(self: Dataset) -> None:
         """Validate DataFrame schema against expected class schema.
 
         Raises:
-            ValueError: DataFrame schema is not valid
+            SchemaValidationError: If the DataFrame schema does not match the expected schema
         """
         expected_schema = self._schema
-        expected_fields = flatten_schema(expected_schema)
         observed_schema = self._df.schema
-        observed_fields = flatten_schema(observed_schema)
-
-        # Unexpected fields in dataset
-        if unexpected_field_names := [
-            x.name
-            for x in observed_fields
-            if x.name not in [y.name for y in expected_fields]
-        ]:
-            raise ValueError(
-                f"The {unexpected_field_names} fields are not included in DataFrame schema: {expected_fields}"
-            )
-
-        # Required fields not in dataset
-        required_fields = [x.name for x in expected_schema if not x.nullable]
-        if missing_required_fields := [
-            req
-            for req in required_fields
-            if not any(field.name == req for field in observed_fields)
-        ]:
-            raise ValueError(
-                f"The {missing_required_fields} fields are required but missing: {required_fields}"
-            )
-
-        # Fields with duplicated names
-        if duplicated_fields := [
-            x for x in set(observed_fields) if observed_fields.count(x) > 1
-        ]:
-            raise ValueError(
-                f"The following fields are duplicated in DataFrame schema: {duplicated_fields}"
-            )
-
-        # Fields with different datatype
-        observed_field_types = {
-            field.name: type(field.dataType) for field in observed_fields
-        }
-        expected_field_types = {
-            field.name: type(field.dataType) for field in expected_fields
-        }
-        if fields_with_different_observed_datatype := [
-            name
-            for name, observed_type in observed_field_types.items()
-            if name in expected_field_types
-            and observed_type != expected_field_types[name]
-        ]:
-            raise ValueError(
-                f"The following fields present differences in their datatypes: {fields_with_different_observed_datatype}."
+
+        if discrepancies := compare_struct_schemas(observed_schema, expected_schema):
+            raise SchemaValidationError(
+                f"Schema validation failed for {type(self).__name__}", discrepancies
             )
 
     def valid_rows(self: Self, invalid_flags: list[str], invalid: bool = False) -> Self:
diff --git a/src/gentropy/dataset/variant_index.py b/src/gentropy/dataset/variant_index.py
index 1cc1eac1b..2f24cd985 100644
--- a/src/gentropy/dataset/variant_index.py
+++ b/src/gentropy/dataset/variant_index.py
@@ -6,9 +6,11 @@ from typing import TYPE_CHECKING
 
 import pyspark.sql.functions as f
+import pyspark.sql.types as t
 
 from gentropy.common.schemas import parse_spark_schema
 from gentropy.common.spark_helpers import (
+    get_nested_struct_schema,
     get_record_with_maximum_value,
     normalise_column,
     rename_all_columns,
@@ -131,6 +133,7 @@ def add_annotation(
 
         # Prefix for renaming columns:
         prefix = "annotation_"
+
         # Generate select expressions to merge and import columns from annotation:
         select_expressions = []
@@ -141,10 +144,17 @@ def add_annotation(
             # If an annotation column can be found in both datasets:
             if (column in self.df.columns) and (column in annotation_source.df.columns):
                 # Arrays are merged:
-                if "ArrayType" in field.dataType.__str__():
+                if isinstance(field.dataType, t.ArrayType):
+                    fields_order = None
+                    if isinstance(field.dataType.elementType, t.StructType):
+                        # Extract the schema of the array to get the order of the fields:
+                        array_schema = [
+                            field for field in VariantIndex.get_schema().fields if field.name == column
+                        ][0].dataType
+                        fields_order = get_nested_struct_schema(array_schema).fieldNames()
                     select_expressions.append(
                         safe_array_union(
-                            f.col(column), f.col(f"{prefix}{column}")
+                            f.col(column), f.col(f"{prefix}{column}"), fields_order
                         ).alias(column)
                     )
                 # Non-array columns are coalesced:
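The list-comprehension-based field-order extraction above is easier to follow standalone; this sketch shows the same idea on a hand-built schema (the column and field names are illustrative, and `elementType.fieldNames()` stands in for `get_nested_struct_schema`):

```python
from pyspark.sql.types import (
    ArrayType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)

# A toy stand-in for VariantIndex.get_schema():
schema = StructType(
    [
        StructField(
            "transcriptConsequences",  # hypothetical array-of-struct column
            ArrayType(
                StructType(
                    [
                        StructField("variantId", StringType(), True),
                        StructField("distance", IntegerType(), True),
                    ]
                )
            ),
            True,
        )
    ]
)

column = "transcriptConsequences"
# Pick the field matching the column, then read the nested struct's field order:
array_schema = [field for field in schema.fields if field.name == column][0].dataType
fields_order = array_schema.elementType.fieldNames()
print(fields_order)  # ['variantId', 'distance'] -> passed to safe_array_union
```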
diff --git a/src/gentropy/datasource/open_targets/variants.py b/src/gentropy/datasource/open_targets/variants.py
index 03018438b..5b6822ae6 100644
--- a/src/gentropy/datasource/open_targets/variants.py
+++ b/src/gentropy/datasource/open_targets/variants.py
@@ -95,7 +95,6 @@ def as_vcf_df(
             variant_df = variant_df.withColumn(
                 col, create_empty_column_if_not_exists(col)
             )
-
         return (
             variant_df.filter(f.col("variantId").isNotNull())
             .withColumn(
diff --git a/src/gentropy/eqtl_catalogue.py b/src/gentropy/eqtl_catalogue.py
index 7adc5d8a2..3ad61ddea 100644
--- a/src/gentropy/eqtl_catalogue.py
+++ b/src/gentropy/eqtl_catalogue.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from gentropy.common.session import Session
+from gentropy.config import EqtlCatalogueConfig
 from gentropy.datasource.eqtl_catalogue.finemapping import EqtlCatalogueFinemapping
 from gentropy.datasource.eqtl_catalogue.study_index import EqtlCatalogueStudyIndex
 
@@ -20,6 +21,7 @@ def __init__(
         eqtl_catalogue_paths_imported: str,
         eqtl_catalogue_study_index_out: str,
         eqtl_catalogue_credible_sets_out: str,
+        eqtl_lead_pvalue_threshold: float = EqtlCatalogueConfig().eqtl_lead_pvalue_threshold,
     ) -> None:
         """Run eQTL Catalogue ingestion step.
 
@@ -29,6 +31,7 @@ def __init__(
             eqtl_catalogue_paths_imported (str): Input eQTL Catalogue fine mapping results path.
             eqtl_catalogue_study_index_out (str): Output eQTL Catalogue study index path.
             eqtl_catalogue_credible_sets_out (str): Output eQTL Catalogue credible sets path.
+            eqtl_lead_pvalue_threshold (float, optional): Lead p-value threshold. Defaults to EqtlCatalogueConfig().eqtl_lead_pvalue_threshold.
         """
         # Extract
         studies_metadata = EqtlCatalogueStudyIndex.read_studies_from_source(
@@ -58,13 +61,19 @@ def __init__(
         processed_susie_df = EqtlCatalogueFinemapping.parse_susie_results(
             credible_sets_df, lbf_df, studies_metadata
         )
-        credible_sets = EqtlCatalogueFinemapping.from_susie_results(processed_susie_df)
-        study_index = EqtlCatalogueStudyIndex.from_susie_results(processed_susie_df)
 
-        # Load
-        study_index.df.write.mode(session.write_mode).parquet(
-            eqtl_catalogue_study_index_out
+        (
+            EqtlCatalogueStudyIndex.from_susie_results(processed_susie_df)
+            # Writing the output:
+            .df.write.mode(session.write_mode)
+            .parquet(eqtl_catalogue_study_index_out)
         )
-        credible_sets.df.write.mode(session.write_mode).parquet(
-            eqtl_catalogue_credible_sets_out
+
+        (
+            EqtlCatalogueFinemapping.from_susie_results(processed_susie_df)
+            # Flagging sub-significant loci:
+            .validate_lead_pvalue(pvalue_cutoff=eqtl_lead_pvalue_threshold)
+            # Writing the output:
+            .df.write.mode(session.write_mode)
+            .parquet(eqtl_catalogue_credible_sets_out)
         )
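For reference, a hedged sketch of what the "flagging sub-significant loci" steps imply, assuming `validate_lead_pvalue` appends a quality-control flag rather than dropping rows (the column names follow the mantissa/exponent convention used elsewhere in gentropy; the flag text is illustrative):

```python
import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [("L1", 2.0, -8), ("L2", 5.0, -4)],
    "studyLocusId string, pValueMantissa double, pValueExponent int",
)

cutoff = 1e-5  # e.g. finngen_finemapping_lead_pvalue_threshold
flagged = df.withColumn(
    "qualityControls",
    f.when(
        # Reassemble the p-value and compare it against the cutoff:
        f.col("pValueMantissa") * f.pow(f.lit(10.0), f.col("pValueExponent"))
        > f.lit(cutoff),
        f.array(f.lit("Subsignificant p-value")),
    ).otherwise(f.array().cast("array<string>")),
)
flagged.show()  # L1 (2e-8) passes; L2 (5e-4) is flagged, not removed
```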
diff --git a/src/gentropy/finngen_finemapping_ingestion.py b/src/gentropy/finngen_finemapping_ingestion.py
index 80089cf68..ca5ca1656 100644
--- a/src/gentropy/finngen_finemapping_ingestion.py
+++ b/src/gentropy/finngen_finemapping_ingestion.py
@@ -20,6 +20,7 @@ def __init__(
         finngen_finemapping_out: str,
         finngen_susie_finemapping_snp_files: str = FinngenFinemappingConfig().finngen_susie_finemapping_snp_files,
         finngen_susie_finemapping_cs_summary_files: str = FinngenFinemappingConfig().finngen_susie_finemapping_cs_summary_files,
+        finngen_finemapping_lead_pvalue_threshold: float = FinngenFinemappingConfig().finngen_finemapping_lead_pvalue_threshold,
     ) -> None:
         """Run FinnGen finemapping ingestion step.
 
@@ -28,16 +29,21 @@ def __init__(
             finngen_finemapping_out (str): Output path for the finemapping results in StudyLocus format.
             finngen_susie_finemapping_snp_files (str): Path to the FinnGen SuSIE finemapping results.
             finngen_susie_finemapping_cs_summary_files (str): FinnGen SuSIE summaries for CS filters (LBF > 2).
+            finngen_finemapping_lead_pvalue_threshold (float): Lead p-value threshold.
         """
         # Read finemapping outputs from the input paths.
-        finngen_finemapping_df = FinnGenFinemapping.from_finngen_susie_finemapping(
-            spark=session.spark,
-            finngen_susie_finemapping_snp_files=finngen_susie_finemapping_snp_files,
-            finngen_susie_finemapping_cs_summary_files=finngen_susie_finemapping_cs_summary_files,
-        )
-
-        # Write the output.
-        finngen_finemapping_df.df.write.mode(session.write_mode).parquet(
-            finngen_finemapping_out
+        (
+            FinnGenFinemapping.from_finngen_susie_finemapping(
+                spark=session.spark,
+                finngen_susie_finemapping_snp_files=finngen_susie_finemapping_snp_files,
+                finngen_susie_finemapping_cs_summary_files=finngen_susie_finemapping_cs_summary_files,
+            )
+            # Flagging sub-significant loci:
+            .validate_lead_pvalue(
+                pvalue_cutoff=finngen_finemapping_lead_pvalue_threshold
+            )
+            # Writing the output:
+            .df.write.mode(session.write_mode)
+            .parquet(finngen_finemapping_out)
         )
diff --git a/src/gentropy/pics.py b/src/gentropy/pics.py
index e80a37eb6..f96f54997 100644
--- a/src/gentropy/pics.py
+++ b/src/gentropy/pics.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 from gentropy.common.session import Session
+from gentropy.config import WindowBasedClumpingStepConfig
 from gentropy.dataset.study_locus import CredibleInterval, StudyLocus
 from gentropy.method.pics import PICS
 
@@ -28,8 +29,14 @@ def __init__(
             session, study_locus_ld_annotated_in
         )
         # PICS
-        picsed_sl = PICS.finemap(study_locus_ld_annotated).filter_credible_set(
-            credible_interval=CredibleInterval.IS99
+        (
+            PICS.finemap(study_locus_ld_annotated)
+            .filter_credible_set(credible_interval=CredibleInterval.IS99)
+            # Flagging sub-significant loci:
+            .validate_lead_pvalue(
+                pvalue_cutoff=WindowBasedClumpingStepConfig().gwas_significance
+            )
+            # Writing the output:
+            .df.write.mode(session.write_mode)
+            .parquet(picsed_study_locus_out)
         )
-        # Write
-        picsed_sl.df.write.mode(session.write_mode).parquet(picsed_study_locus_out)
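One more illustration before the new tests: the message format produced by `SchemaValidationError.__str__`, built by hand from the class added in `schemas.py` above (dataset and column names are illustrative):

```python
from collections import defaultdict

from gentropy.common.schemas import SchemaValidationError

errors: defaultdict[str, list[str]] = defaultdict(list)
errors["unexpected_columns"].append("extraField")
errors["missing_mandatory_columns"].append("geneId")

print(SchemaValidationError("Schema validation failed for StudyIndex", errors))
# Schema validation failed for StudyIndex
# Errors:
#  unexpected_columns: extraField
#  missing_mandatory_columns: geneId
```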
diff --git a/tests/gentropy/common/test_schema_methods.py b/tests/gentropy/common/test_schema_methods.py
new file mode 100644
index 000000000..8ed1342b5
--- /dev/null
+++ b/tests/gentropy/common/test_schema_methods.py
@@ -0,0 +1,303 @@
+"""Tests methods dealing with schema comparison."""
+
+from __future__ import annotations
+
+from collections import defaultdict
+
+from pyspark.sql.types import (
+    ArrayType,
+    IntegerType,
+    StringType,
+    StructField,
+    StructType,
+)
+
+from gentropy.common.schemas import (
+    compare_array_schemas,
+    compare_struct_schemas,
+)
+
+
+class TestSchemaComparisonMethods:
+    """Class for testing schema comparison methods."""
+
+    STRUCT_FIELD_STRING = StructField("a", StringType(), True)
+    STRUCT_FIELD_STRING_MANDATORY = StructField("a", StringType(), False)
+    STRUCT_FIELD_INTEGER = StructField("b", IntegerType(), True)
+    STRUCT_FIELD_WRONGTYPE = StructField("a", IntegerType(), True)
+
+    def test_struct_validation_return_type(self: TestSchemaComparisonMethods) -> None:
+        """Test the return type of a successful StructType comparison."""
+        observed = StructType([self.STRUCT_FIELD_STRING, self.STRUCT_FIELD_INTEGER])
+        expected = StructType([self.STRUCT_FIELD_STRING, self.STRUCT_FIELD_INTEGER])
+
+        discrepancy = compare_struct_schemas(observed, expected)
+        assert isinstance(discrepancy, defaultdict)
+
+    def test_struct_validation_success(self: TestSchemaComparisonMethods) -> None:
+        """Test successful validation of StructType."""
+        observed = StructType([self.STRUCT_FIELD_STRING, self.STRUCT_FIELD_INTEGER])
+        expected = StructType([self.STRUCT_FIELD_STRING, self.STRUCT_FIELD_INTEGER])
+
+        discrepancy = compare_struct_schemas(observed, expected)
+        assert not discrepancy
+
+    def test_struct_validation_non_matching_type(
+        self: TestSchemaComparisonMethods,
+    ) -> None:
+        """Test unsuccessful validation of StructType."""
+        observed = StructType([self.STRUCT_FIELD_STRING])
+        expected = StructType([self.STRUCT_FIELD_WRONGTYPE])
+
+        discrepancy = compare_struct_schemas(observed, expected)
+
+        # Test there's a discrepancy:
+        assert discrepancy
+
+        # Test that the discrepancy is in the field name:
+        assert "columns_with_non_matching_type" in discrepancy
+
+    def test_struct_validation_missing_mandatory(
+        self: TestSchemaComparisonMethods,
+    ) -> None:
+        """Test unsuccessful validation of StructType."""
+        observed = StructType([self.STRUCT_FIELD_INTEGER])
+        expected = StructType(
+            [self.STRUCT_FIELD_STRING_MANDATORY, self.STRUCT_FIELD_INTEGER]
+        )
+
+        discrepancy = compare_struct_schemas(observed, expected)
+
+        # Test there's a discrepancy:
+        assert discrepancy
+
+        # Test that the discrepancy is in the field name:
+        assert "missing_mandatory_columns" in discrepancy
+
+        # Test that the right column is flagged as missing:
+        assert (
+            self.STRUCT_FIELD_STRING_MANDATORY.name
+            in discrepancy["missing_mandatory_columns"]
+        )
+
+    def test_struct_validation_unexpected_column(
+        self: TestSchemaComparisonMethods,
+    ) -> None:
+        """Test unsuccessful validation of StructType."""
+        observed = StructType(
+            [self.STRUCT_FIELD_STRING_MANDATORY, self.STRUCT_FIELD_INTEGER]
+        )
+        expected = StructType([self.STRUCT_FIELD_STRING_MANDATORY])
+
+        discrepancy = compare_struct_schemas(observed, expected)
+
+        # Test there's a discrepancy:
+        assert discrepancy
+
+        # Test that the discrepancy is in the field name:
+        assert "unexpected_columns" in discrepancy
+
+        # Test that the right column is flagged as unexpected:
+        assert self.STRUCT_FIELD_INTEGER.name in discrepancy["unexpected_columns"]
+
+    def test_struct_validation_duplicated_columns(
+        self: TestSchemaComparisonMethods,
+    ) -> None:
+        """Test unsuccessful validation of StructType."""
+        observed = StructType(
+            [
+                self.STRUCT_FIELD_STRING,
+                self.STRUCT_FIELD_STRING,
+                self.STRUCT_FIELD_INTEGER,
+            ]
+        )
+        expected = StructType([self.STRUCT_FIELD_STRING, self.STRUCT_FIELD_INTEGER])
+
+        discrepancy = compare_struct_schemas(observed, expected)
+
+        # Test there's a discrepancy:
+        assert discrepancy
+
+        # Test that the discrepancy is in the field name:
+        assert "duplicated_columns" in discrepancy
+
+        # Test that the right column is flagged as duplicated:
+        assert self.STRUCT_FIELD_STRING.name in discrepancy["duplicated_columns"]
+
+    def test_struct_validation_success_nested_struct(
+        self: TestSchemaComparisonMethods,
+    ) -> None:
+        """Test successful validation of nested StructType."""
+        nested_struct = StructType(
+            [self.STRUCT_FIELD_STRING, self.STRUCT_FIELD_INTEGER]
+        )
+
+        observed = StructType([StructField("c", nested_struct)])
+        expected = StructType([StructField("c", nested_struct)])
+
+        discrepancy = compare_struct_schemas(observed, expected)
+        assert not discrepancy
+
+    def test_struct_validation_non_matching_type_nested_struct(
+        self: TestSchemaComparisonMethods,
+    ) -> None:
+        """Test unsuccessful validation of nested StructType."""
+        nested_struct = StructType([self.STRUCT_FIELD_STRING])
+
+        observed = StructType([StructField("c", nested_struct)])
+        expected = StructType(
+            [StructField("c", StructType([self.STRUCT_FIELD_WRONGTYPE]))]
+        )
+
+        discrepancy = compare_struct_schemas(observed, expected)
+
+        # Test there's a discrepancy:
+        assert discrepancy
+
+        # Test that the discrepancy is in the field name:
+        assert "columns_with_non_matching_type" in discrepancy
+
+    def test_array_validation_success(self: TestSchemaComparisonMethods) -> None:
+        """Test successful validation of ArrayType."""
+        observed = ArrayType(StringType())
+        expected = ArrayType(StringType())
+
+        discrepancy = compare_array_schemas(observed, expected)
+        assert not discrepancy
+
+    def test_array_validation_non_matching_type(
+        self: TestSchemaComparisonMethods,
+    ) -> None:
+        """Test unsuccessful validation of ArrayType."""
+        observed = ArrayType(StringType())
+        expected = ArrayType(IntegerType())
+
+        discrepancy = compare_array_schemas(observed, expected)
+
+        # Test there's a discrepancy:
+        assert discrepancy
+
+        # Test that the discrepancy is in the field name:
+        assert "columns_with_non_matching_type" in discrepancy
+
+    def test_array_validation_nested_array(self: TestSchemaComparisonMethods) -> None:
+        """Test successful validation of nested ArrayType."""
+        nested_array = ArrayType(StringType())
+
+        observed = ArrayType(nested_array)
+        expected = ArrayType(nested_array)
+
+        discrepancy = compare_array_schemas(observed, expected)
+        assert not discrepancy
+
+    def test_array_validation_non_matching_type_nested_array(
+        self: TestSchemaComparisonMethods,
+    ) -> None:
+        """Test unsuccessful validation of nested ArrayType."""
+        observed = ArrayType(ArrayType(StringType()))
+        expected = ArrayType(ArrayType(IntegerType()))
+
+        discrepancy = compare_array_schemas(observed, expected)
+
+        # Test there's a discrepancy:
+        assert discrepancy
+
+        # Test that the discrepancy is in the field name:
+        assert "columns_with_non_matching_type" in discrepancy
+
+    def test_struct_validation_success_nested_with_array(
+        self: TestSchemaComparisonMethods,
+    ) -> None:
+        """Test successful validation of nested StructType with ArrayType."""
+        nested_array = StructField("a", ArrayType(StringType()), True)
+        nested_struct = StructType([self.STRUCT_FIELD_STRING, nested_array])
+
+        observed = StructType([StructField("c", nested_struct, True)])
+        expected = StructType([StructField("c", nested_struct, True)])
+
+        discrepancy = compare_struct_schemas(observed, expected)
+        assert not discrepancy
+
+    def test_struct_validation_non_matching_type_nested_with_array(
+        self: TestSchemaComparisonMethods,
+    ) -> None:
+        """Test unsuccessful validation of nested StructType with ArrayType."""
+        nested_array = StructField("a", ArrayType(StringType()), True)
+        nested_array_wrong_type = StructField("a", ArrayType(IntegerType()), True)
+        nested_struct = StructType([self.STRUCT_FIELD_STRING, nested_array])
+        nested_struct_wrong_type = StructType(
+            [self.STRUCT_FIELD_STRING, nested_array_wrong_type]
+        )
+        observed = StructType([StructField("c", nested_struct, True)])
+        expected = StructType([StructField("c", nested_struct_wrong_type, True)])
+
+        discrepancy = compare_struct_schemas(observed, expected)
+
+        # Test there's a discrepancy:
+        assert discrepancy
+
+        # Test that the discrepancy is in the field name:
+        assert "columns_with_non_matching_type" in discrepancy
+
+    def test_struct_validation_failing_with_multiple_reasons(
+        self: TestSchemaComparisonMethods,
+    ) -> None:
+        """Test unsuccessful validation of StructType with multiple issues."""
+        observed = StructType(
+            [
+                StructField(
+                    "a",
+                    ArrayType(
+                        ArrayType(
+                            StructType(
+                                [
+                                    StructField("a", IntegerType(), False),
+                                    StructField("c", StringType(), True),
+                                    StructField("c", StringType(), True),
+                                ]
+                            ),
+                            False,
+                        ),
+                        False,
+                    ),
+                    False,
+                ),
+            ]
+        )
+
+        expected = StructType(
+            [
+                StructField(
+                    "a",
+                    ArrayType(
+                        ArrayType(
+                            StructType(
+                                [
+                                    StructField("b", IntegerType(), False),
+                                    StructField("c", StringType(), True),
+                                    StructField("d", StringType(), True),
+                                ]
+                            ),
+                            False,
+                        ),
+                        False,
+                    ),
+                    False,
+                ),
+            ]
+        )
+
+        discrepancy = compare_struct_schemas(observed, expected)
+
+        # Test there's a discrepancy:
+        assert discrepancy
+
+        # Test if the returned list of discrepancies is correct:
+        assert discrepancy == defaultdict(
+            list,
+            {
+                "duplicated_columns": ["a[][].c"],
+                "missing_mandatory_columns": ["a[][].b"],
+                "unexpected_columns": ["a[][].a"],
+            },
+        )
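The repeated `f.array().cast("array<string>")` changes in the test fixtures below follow from the stricter type comparison: an argument-less `f.array()` produces `array<void>` (`array<null>` on older Spark versions), which no longer matches an expected `array<string>` column. A quick demonstration:

```python
import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.range(1).select(
    f.array().alias("untyped"),  # element type is void/null
    f.array().cast("array<string>").alias("typed"),  # element type is string
)
df.printSchema()
# `untyped` would be reported under columns_with_non_matching_type;
# `typed` matches the expected array<string> schema.
```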
StructType with multiple issues.""" + observed = StructType( + [ + StructField( + "a", + ArrayType( + ArrayType( + StructType( + [ + StructField("a", IntegerType(), False), + StructField("c", StringType(), True), + StructField("c", StringType(), True), + ] + ), + False, + ), + False, + ), + False, + ), + ] + ) + + expected = StructType( + [ + StructField( + "a", + ArrayType( + ArrayType( + StructType( + [ + StructField("b", IntegerType(), False), + StructField("c", StringType(), True), + StructField("d", StringType(), True), + ] + ), + False, + ), + False, + ), + False, + ), + ] + ) + + discrepancy = compare_struct_schemas(observed, expected) + + # Test there's a discrepancy: + assert discrepancy + + # Test if the returned list of discrepancies is correct: + assert discrepancy == defaultdict( + list, + { + "duplicated_columns": ["a[][].c"], + "missing_mandatory_columns": ["a[][].b"], + "unexpected_columns": ["a[][].a"], + }, + ) diff --git a/tests/gentropy/dataset/test_study_index.py b/tests/gentropy/dataset/test_study_index.py index 3bdd7a5cb..fee3a2557 100644 --- a/tests/gentropy/dataset/test_study_index.py +++ b/tests/gentropy/dataset/test_study_index.py @@ -167,14 +167,14 @@ def _setup(self: TestGeneValidation, spark: SparkSession) -> None: """Setup fixture.""" self.study_index = StudyIndex( _df=spark.createDataFrame(self.STUDY_DATA, self.STUDY_COLUMNS).withColumn( - "qualityControls", f.array() + "qualityControls", f.array().cast("array") ), _schema=StudyIndex.get_schema(), ) self.study_index_no_gene = StudyIndex( _df=spark.createDataFrame(self.STUDY_DATA, self.STUDY_COLUMNS) - .withColumn("qualityControls", f.array()) + .withColumn("qualityControls", f.array().cast("array")) .drop("geneId"), _schema=StudyIndex.get_schema(), ) @@ -231,7 +231,7 @@ def _setup(self: TestUniquenessValidation, spark: SparkSession) -> None: """Setup fixture.""" self.study_index = StudyIndex( _df=spark.createDataFrame(self.STUDY_DATA, self.STUDY_COLUMNS).withColumn( - "qualityControls", f.array() + "qualityControls", f.array().cast("array") ), _schema=StudyIndex.get_schema(), ) @@ -279,7 +279,7 @@ def _setup(self: TestStudyTypeValidation, spark: SparkSession) -> None: """Setup fixture.""" self.study_index = StudyIndex( _df=spark.createDataFrame(self.STUDY_DATA, self.STUDY_COLUMNS).withColumn( - "qualityControls", f.array() + "qualityControls", f.array().cast("array") ), _schema=StudyIndex.get_schema(), ) @@ -346,8 +346,10 @@ def _setup(self: TestDiseaseValidation, spark: SparkSession) -> None: spark.createDataFrame(self.STUDY_DATA, self.STUDY_COLUMNS) .groupBy("studyId", "studyType", "projectId") .agg(f.collect_set("efo").alias("traitFromSourceMappedIds")) - .withColumn("qualityControls", f.array()) - .withColumn("backgroundTraitFromSourceMappedIds", f.array()) + .withColumn("qualityControls", f.array().cast("array")) + .withColumn( + "backgroundTraitFromSourceMappedIds", f.array().cast("array") + ) ) study_df.show() # Mock study index: diff --git a/tests/gentropy/method/test_clump.py b/tests/gentropy/method/test_clump.py index 83a95e19f..ed07608db 100644 --- a/tests/gentropy/method/test_clump.py +++ b/tests/gentropy/method/test_clump.py @@ -135,7 +135,9 @@ def test_flagging(self: TestIsLeadLinked) -> None: """Test flagging of lead variants.""" # Create the study locus and clump: sl_flagged = StudyLocus( - _df=self.df.drop("expected_flag").withColumn("qualityControls", f.array()), + _df=self.df.drop("expected_flag").withColumn( + "qualityControls", f.array().cast("array") + ), 
             _schema=StudyLocus.get_schema(),
         ).clump()
diff --git a/tests/gentropy/test_schemas.py b/tests/gentropy/test_schemas.py
index 1af72c149..6840e3207 100644
--- a/tests/gentropy/test_schemas.py
+++ b/tests/gentropy/test_schemas.py
@@ -12,6 +12,8 @@
 import pytest
 from pyspark.sql.types import StructType
 
+from gentropy.common.schemas import SchemaValidationError
+
 if TYPE_CHECKING:
     from _pytest.fixtures import FixtureRequest
 
@@ -90,7 +92,7 @@ def test_validate_schema_extra_field(
     mock_dataset_instance: V2G | GeneIndex,
 ) -> None:
     """Test that validate_schema raises an error if the observed schema has an extra field."""
-    with pytest.raises(ValueError, match="extraField"):
+    with pytest.raises(SchemaValidationError, match="extraField"):
         mock_dataset_instance.df = mock_dataset_instance.df.withColumn(
             "extraField", f.lit("extra")
         )
@@ -103,7 +105,7 @@ def test_validate_schema_missing_field(
     mock_dataset_instance: V2G | GeneIndex,
 ) -> None:
     """Test that validate_schema raises an error if the observed schema is missing a required field, geneId in this case."""
-    with pytest.raises(ValueError, match="geneId"):
+    with pytest.raises(SchemaValidationError, match="geneId"):
         mock_dataset_instance.df = mock_dataset_instance.df.drop("geneId")
 
 @pytest.mark.parametrize(
@@ -114,7 +116,7 @@ def test_validate_schema_duplicated_field(
     mock_dataset_instance: V2G | GeneIndex,
 ) -> None:
     """Test that validate_schema raises an error if the observed schema has a duplicated field, geneId in this case."""
-    with pytest.raises(ValueError, match="geneId"):
+    with pytest.raises(SchemaValidationError, match="geneId"):
         mock_dataset_instance.df = mock_dataset_instance.df.select(
             "*", f.lit("A").alias("geneId")
         )
@@ -127,7 +129,7 @@ def test_validate_schema_different_datatype(
     mock_dataset_instance: V2G | GeneIndex,
 ) -> None:
     """Test that validate_schema raises an error if any field in the observed schema has a different type than expected."""
-    with pytest.raises(ValueError, match="geneId"):
+    with pytest.raises(SchemaValidationError, match="geneId"):
         mock_dataset_instance.df = mock_dataset_instance.df.withColumn(
             "geneId", f.lit(1)
         )
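Taken together, the validation path now surfaces a structured error. A minimal end-to-end sketch in the style of the updated tests (the `mock_gene_index` fixture is hypothetical, standing in for the project's existing mocks):

```python
import pyspark.sql.functions as f
import pytest

from gentropy.common.schemas import SchemaValidationError
from gentropy.dataset.gene_index import GeneIndex


def test_extra_column_is_reported(mock_gene_index: GeneIndex) -> None:
    """Assigning a DataFrame with an unexpected column triggers validation."""
    with pytest.raises(SchemaValidationError, match="extraField"):
        # The `df` setter re-runs validate_schema, which now raises a
        # SchemaValidationError listing the unexpected column.
        mock_gene_index.df = mock_gene_index.df.withColumn(
            "extraField", f.lit("extra")
        )
```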