def qc_abnormal_pips(
    self: StudyLocus,
    sum_pips_lower_threshold: float = 0.99,
    # Upper bound is set slightly above 1 to tolerate floating point error in the sum.
    sum_pips_upper_threshold: float = 1.0001,
) -> StudyLocus:
    """Flag study loci whose sum of posterior inclusion probabilities (PIPs) is out of the expected range.

    Loci whose PIPs sum to less than `sum_pips_lower_threshold` or more than
    `sum_pips_upper_threshold` receive the `ABNORMAL_PIPS` quality-control flag.
    No rows are removed; the full dataset is returned with the flag annotated.

    Args:
        sum_pips_lower_threshold (float): Lower threshold for the sum of PIPs. Defaults to 0.99.
        sum_pips_upper_threshold (float): Upper threshold for the sum of PIPs. Defaults to 1.0001.

    Returns:
        StudyLocus: Study-locus dataset with abnormal-PIP loci flagged in `qualityControls`.
    """
    # QC column might not be present, so we have to be ready to handle it:
    qc_select_expression = (
        f.col("qualityControls")
        if "qualityControls" in self.df.columns
        else f.lit(None).cast(ArrayType(StringType()))
    )

    flag = (
        self.df
        # Sum the posterior probability of every variant in the locus.
        # NOTE(review): a null posteriorProbability (or a null locus) propagates
        # null through the sum, so such loci are never flagged — confirm intended.
        .withColumn(
            "sumPosteriorProbability",
            f.aggregate(
                f.col("locus"),
                f.lit(0.0),
                lambda acc, x: acc + x["posteriorProbability"],
            ),
        )
        # A null sum compares as null, which `otherwise` maps to False (not flagged).
        .withColumn(
            "pipOutOfRange",
            f.when(
                (f.col("sumPosteriorProbability") < sum_pips_lower_threshold)
                | (f.col("sumPosteriorProbability") > sum_pips_upper_threshold),
                True,
            ).otherwise(False),
        )
    )

    return StudyLocus(
        _df=(
            flag
            # Flagging loci whose sum of PIPs falls outside the expected range:
            .withColumn(
                "qualityControls",
                self.update_quality_flag(
                    qc_select_expression,
                    f.col("pipOutOfRange"),
                    StudyLocusQualityCheck.ABNORMAL_PIPS,
                ),
            ).drop("sumPosteriorProbability", "pipOutOfRange")
        ),
        _schema=self.get_schema(),
    )
def test_qc_abnormal_pips(mock_study_locus: StudyLocus) -> None:
    """Test that the qc_abnormal_pips method returns a StudyLocus object."""
    assert isinstance(mock_study_locus.qc_abnormal_pips(0.99, 1), StudyLocus)


# Used primarily for test_unique_variants_in_locus but also for other tests.
# Each entry is an (observed, expected) pair of row tuples.
test_unique_variants_in_locus_test_data = [
    (
        # Locus is not null, should return union between variants in locus and lead variant
        [
            (
                "1",
                "traitA",
                "22_varA",
                [
                    {"variantId": "22_varA", "posteriorProbability": 0.44},
                    {"variantId": "22_varB", "posteriorProbability": 0.015},
                ],
            ),
        ],
        [
            ("22_varA", "22"),
            ("22_varB", "22"),
        ],
    ),
    (
        # locus is null, should return lead variant
        [
            ("1", "traitA", "22_varA", None),
        ],
        [
            ("22_varA", "22"),
        ],
    ),
]

# Schema matching the observed tuples above; posteriorProbability is included so
# the same fixtures can also drive the qc_abnormal_pips tests.
test_unique_variants_in_locus_test_schema = StructType(
    [
        StructField("studyLocusId", StringType(), True),
        StructField("studyId", StringType(), True),
        StructField("variantId", StringType(), True),
        StructField(
            "locus",
            ArrayType(
                StructType(
                    [
                        StructField("variantId", StringType(), True),
                        StructField("posteriorProbability", DoubleType(), True),
                    ]
                )
            ),
            True,
        ),
    ]
)


@pytest.mark.parametrize(
    ("observed", "expected"),
    test_unique_variants_in_locus_test_data,
)
def test_unique_variants_in_locus(
    spark: SparkSession, observed: list[Any], expected: list[Any]
) -> None:
    """Test unique variants in locus."""
    data_sl = StudyLocus(
        _df=spark.createDataFrame(observed, test_unique_variants_in_locus_test_schema),
        _schema=StudyLocus.get_schema(),
    )
    expected_df = spark.createDataFrame(
        expected, schema="variantId: string, chromosome: string"
    )
# Used primarily for test_annotate_credible_sets but also for other tests.
# Each entry is an (observed, expected) pair of row tuples.
test_annotate_credible_sets_test_data = [
    (
        # Simple case
        [
            # Observed
            (
                "1",
                "traitA",
                "leadB",
                [{"variantId": "tagVariantA", "posteriorProbability": 1.0}],
            ),
        ],
        [
            # Expected
            (
                "1",
                "traitA",
                "leadB",
                [
                    {
                        "variantId": "tagVariantA",
                        "posteriorProbability": 1.0,
                        "is95CredibleSet": True,
                        "is99CredibleSet": True,
                    }
                ],
            )
        ],
    ),
    (
        # Unordered credible set
        [
            # Observed
            (
                "1",
                "traitA",
                "leadA",
                [
                    {"variantId": "tagVariantA", "posteriorProbability": 0.44},
                    {"variantId": "tagVariantB", "posteriorProbability": 0.015},
                    {"variantId": "tagVariantC", "posteriorProbability": 0.04},
                    {"variantId": "tagVariantD", "posteriorProbability": 0.005},
                    {"variantId": "tagVariantE", "posteriorProbability": 0.5},
                    {"variantId": "tagVariantNull", "posteriorProbability": None},
                    {"variantId": "tagVariantNull", "posteriorProbability": None},
                ],
            )
        ],
        [
            # Expected
            (
                "1",
                "traitA",
                "leadA",
                [
                    {
                        "variantId": "tagVariantE",
                        "posteriorProbability": 0.5,
                        "is95CredibleSet": True,
                        "is99CredibleSet": True,
                    },
                    {
                        "variantId": "tagVariantA",
                        "posteriorProbability": 0.44,
                        "is95CredibleSet": True,
                        "is99CredibleSet": True,
                    },
                    {
                        "variantId": "tagVariantC",
                        "posteriorProbability": 0.04,
                        "is95CredibleSet": True,
                        "is99CredibleSet": True,
                    },
                    {
                        "variantId": "tagVariantB",
                        "posteriorProbability": 0.015,
                        "is95CredibleSet": False,
                        "is99CredibleSet": True,
                    },
                    {
                        "variantId": "tagVariantD",
                        "posteriorProbability": 0.005,
                        "is95CredibleSet": False,
                        "is99CredibleSet": False,
                    },
                    {
                        "variantId": "tagVariantNull",
                        "posteriorProbability": None,
                        "is95CredibleSet": False,
                        "is99CredibleSet": False,
                    },
                    {
                        "variantId": "tagVariantNull",
                        "posteriorProbability": None,
                        "is95CredibleSet": False,
                        "is99CredibleSet": False,
                    },
                ],
            )
        ],
    ),
    (
        # Null credible set
        [
            # Observed
            ("1", "traitA", "leadB", None),
        ],
        [
            # Expected
            ("1", "traitA", "leadB", None)
        ],
    ),
    (
        # Empty credible set
        [
            # Observed
            ("1", "traitA", "leadB", []),
        ],
        [
            # Expected
            ("1", "traitA", "leadB", None)
        ],
    ),
]

# Schema shared by observed and expected rows above (credible-set flags included).
test_annotate_credible_sets_test_schema = StructType(
    [
        StructField("studyLocusId", StringType(), True),
        StructField("studyId", StringType(), True),
        StructField("variantId", StringType(), True),
        StructField(
            "locus",
            ArrayType(
                StructType(
                    [
                        StructField("variantId", StringType(), True),
                        StructField("posteriorProbability", DoubleType(), True),
                        StructField("is95CredibleSet", BooleanType(), True),
                        StructField("is99CredibleSet", BooleanType(), True),
                    ]
                )
            ),
            True,
        ),
    ]
)


@pytest.mark.parametrize(
    ("observed", "expected"),
    test_annotate_credible_sets_test_data,
)
def test_annotate_credible_sets(
    spark: SparkSession, observed: list[Any], expected: list[Any]
) -> None:
    """Test annotate_credible_sets."""
    data_sl = StudyLocus(
        _df=spark.createDataFrame(observed, test_annotate_credible_sets_test_schema),
        _schema=StudyLocus.get_schema(),
    )
    expected_sl = StudyLocus(
        _df=spark.createDataFrame(expected, test_annotate_credible_sets_test_schema),
        _schema=StudyLocus.get_schema(),
    )
    assert data_sl.annotate_credible_sets().df.collect() == expected_sl.df.collect()


def test_qc_abnormal_pips_good_locus(spark: SparkSession) -> None:
    """Test qc_abnormal_pips with a well-behaving locus.

    Uses the simple-case fixture, whose single PIP sums to exactly 1.0 — inside
    the accepted range [0.99, 1.0001] — so no row may be flagged. (The unordered
    fixture contains null PIPs, which null out the sum and would make this check
    pass vacuously without exercising the in-range path.)
    """
    sl = StudyLocus(
        _df=spark.createDataFrame(
            test_annotate_credible_sets_test_data[0][0],
            test_annotate_credible_sets_test_schema,
        ),
        _schema=StudyLocus.get_schema(),
    )
    assert (
        sl.qc_abnormal_pips().df.filter(f.size("qualityControls") > 0).count() == 0
    ), "Expected number of rows differs from observed."
def test_qc_abnormal_pips_bad_locus(spark: SparkSession) -> None:
    """Test qc_abnormal_pips with an abnormal locus.

    The fixture locus has PIPs summing to 0.455 (0.44 + 0.015), which is below
    the 0.99 lower threshold, so the single row must be flagged.
    """
    sl = StudyLocus(
        _df=spark.createDataFrame(
            test_unique_variants_in_locus_test_data[0][0],
            test_unique_variants_in_locus_test_schema,
        ),
        _schema=StudyLocus.get_schema(),
    )
    assert (
        sl.qc_abnormal_pips().df.filter(f.size("qualityControls") > 0).count() == 1
    ), "Expected number of rows differs from observed."