Skip to content

Commit

Permalink
feat: flag PICS top hits in studies with credset sumstats (#777)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Suveges <daniel.suveges@protonmail.com>
  • Loading branch information
d0choa and DSuveges authored Sep 20, 2024
1 parent 018defa commit ad3f503
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/gentropy/dataset/study_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class StudyLocusQualityCheck(Enum):
INVALID_VARIANT_IDENTIFIER (str): Flagging study loci where identifier of any tagging variant was not found in the variant index
TOP_HIT (str): Study locus from curated top hit
IN_MHC (str): Flagging study loci in the MHC region
REDUNDANT_PICS_TOP_HIT (str): Flagging study loci in studies with PICS results from summary statistics
"""

SUBSIGNIFICANT_FLAG = "Subsignificant p-value"
Expand All @@ -74,6 +75,9 @@ class StudyLocusQualityCheck(Enum):
"Some variant identifiers of this locus were not found in variant index"
)
IN_MHC = "MHC region"
REDUNDANT_PICS_TOP_HIT = (
"PICS results from summary statistics available for this same study"
)
TOP_HIT = "Study locus from curated top hit"


Expand Down Expand Up @@ -878,6 +882,44 @@ def qc_MHC_region(self: StudyLocus) -> StudyLocus:
)
return self

def qc_redundant_top_hits_from_PICS(self: StudyLocus) -> StudyLocus:
    """Flag curated top hits in studies that also have PICS results from summary statistics.

    A top hit is considered redundant when the same study has at least one
    PICS-finemapped study locus that does not carry the top-hit flag, i.e. one
    derived from summary statistics that should explain the top hit.

    Returns:
        StudyLocus: Updated study locus with redundant top hits flagged.
    """
    top_hit_label = StudyLocusQualityCheck.TOP_HIT.value

    # Only PICS-finemapped rows contribute to the per-study indicator.
    pics_loci = self.df.filter(f.col("finemappingMethod") == "pics")
    # A PICS row without the top-hit flag counts as coming from summary statistics.
    pics_loci = pics_loci.withColumn(
        "hasPicsSumstats",
        ~f.array_contains("qualityControls", top_hit_label),
    )
    # One row per study: true as soon as any of its PICS rows is not a top hit.
    per_study_indicator = pics_loci.groupBy("studyId").agg(
        f.max(f.col("hasPicsSumstats")).alias("studiesWithPicsSumstats")
    )

    # Left join keeps every study locus, including studies without any PICS row
    # (their indicator is null after the join).
    annotated = self.df.join(per_study_indicator, on="studyId", how="left")
    is_redundant_top_hit = f.array_contains(
        "qualityControls", top_hit_label
    ) & f.col("studiesWithPicsSumstats")

    flagged = annotated.withColumn(
        "qualityControls",
        self.update_quality_flag(
            f.col("qualityControls"),
            is_redundant_top_hit,
            StudyLocusQualityCheck.REDUNDANT_PICS_TOP_HIT,
        ),
    ).drop("studiesWithPicsSumstats")

    return StudyLocus(_df=flagged, _schema=StudyLocus.get_schema())

def _qc_no_population(self: StudyLocus) -> StudyLocus:
"""Flag associations where the study doesn't have population information to resolve LD.
Expand Down
1 change: 1 addition & 0 deletions src/gentropy/study_locus_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
# Add flag for MHC region
.qc_MHC_region()
.validate_study(study_index) # Flagging studies not in study index
.qc_redundant_top_hits_from_PICS() # Flagging top hits from studies with PICS summary statistics
.validate_unique_study_locus_id() # Flagging duplicated study locus ids
).persist() # we will need this for 2 types of outputs

Expand Down
71 changes: 71 additions & 0 deletions tests/gentropy/dataset/test_study_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,74 @@ def test_study_validation_correctness(self: TestStudyLocusValidation) -> None:
)
.count()
) == 1


class TestStudyLocusRedundancyFlagging:
    """Tests for flagging curated top hits made redundant by PICS results from summary statistics."""

    # Columns: studyLocusId, variantId, studyId, finemappingMethod, qualityControls.
    STUDY_LOCUS_DATA = [
        # s1: PICS rows with and without the top-hit flag -> both top hits are redundant.
        (1, "v1", "s1", "pics", []),
        (2, "v2", "s1", "pics", [StudyLocusQualityCheck.TOP_HIT.value]),
        (3, "v3", "s1", "pics", []),
        (3, "v3", "s1", "pics", []),
        (1, "v1", "s1", "pics", [StudyLocusQualityCheck.TOP_HIT.value]),
        # s2: only top hits -> nothing is redundant.
        (1, "v1", "s2", "pics", [StudyLocusQualityCheck.TOP_HIT.value]),
        (1, "v1", "s2", "pics", [StudyLocusQualityCheck.TOP_HIT.value]),
        # s3: the only non-top-hit row is SuSie, not PICS -> top hit is not redundant.
        (1, "v1", "s3", "SuSie", []),
        (1, "v1", "s3", "pics", [StudyLocusQualityCheck.TOP_HIT.value]),
        # s4: one PICS row without the top-hit flag -> the top hit is redundant.
        (1, "v1", "s4", "pics", []),
        (1, "v1", "s4", "SuSie", []),
        (1, "v1", "s4", "pics", [StudyLocusQualityCheck.TOP_HIT.value]),
    ]

    STUDY_LOCUS_SCHEMA = t.StructType(
        [
            t.StructField("studyLocusId", t.LongType(), False),
            t.StructField("variantId", t.StringType(), False),
            t.StructField("studyId", t.StringType(), False),
            t.StructField("finemappingMethod", t.StringType(), False),
            t.StructField("qualityControls", t.ArrayType(t.StringType()), False),
        ]
    )

    @pytest.fixture(autouse=True)
    def _setup(self: TestStudyLocusRedundancyFlagging, spark: SparkSession) -> None:
        """Build the study locus dataset shared by every test in this class."""
        self.study_locus = StudyLocus(
            _df=spark.createDataFrame(
                self.STUDY_LOCUS_DATA, schema=self.STUDY_LOCUS_SCHEMA
            ),
            _schema=StudyLocus.get_schema(),
        )

    def test_qc_redundant_top_hits_from_PICS_returntype(
        self: TestStudyLocusRedundancyFlagging,
    ) -> None:
        """Test that qc_redundant_top_hits_from_PICS returns a StudyLocus."""
        assert isinstance(
            self.study_locus.qc_redundant_top_hits_from_PICS(), StudyLocus
        )

    def test_qc_redundant_top_hits_from_PICS_no_data_loss(
        self: TestStudyLocusRedundancyFlagging,
    ) -> None:
        """Test that the redundancy flagging returns the same number of rows."""
        assert (
            self.study_locus.qc_redundant_top_hits_from_PICS().df.count()
            == self.study_locus.df.count()
        )

    def test_qc_redundant_top_hits_from_PICS_correctness(
        self: TestStudyLocusRedundancyFlagging,
    ) -> None:
        """Test that exactly the expected top hits are flagged as redundant."""
        flagged_rows = (
            self.study_locus.qc_redundant_top_hits_from_PICS()
            .df.filter(
                f.array_contains(
                    f.col("qualityControls"),
                    StudyLocusQualityCheck.REDUNDANT_PICS_TOP_HIT.value,
                )
            )
            .collect()
        )
        assert len(flagged_rows) == 3
        # A bare count would also pass if the wrong rows were flagged, so pin
        # which study loci carry the flag: two top hits in s1 and one in s4.
        assert sorted(
            (row["studyLocusId"], row["studyId"]) for row in flagged_rows
        ) == [(1, "s1"), (1, "s4"), (2, "s1")]

0 comments on commit ad3f503

Please sign in to comment.