Skip to content

Commit

Permalink
feat(ld_annotator): optional r2 threshold (#648)
Browse files Browse the repository at this point in the history
* feat(ld_annotator): apply r2 threshold

* feat(ld_annotator): apply r2 threshold

* chore(ldannotator): change threshold to 0.5
  • Loading branch information
ireneisdoomed authored Jun 18, 2024
1 parent ca377ce commit 6d93192
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 8 deletions.
27 changes: 25 additions & 2 deletions src/gentropy/dataset/study_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,25 @@ def filter_credible_set(
)
return self

@staticmethod
def filter_ld_set(ld_set: Column, r2_threshold: float) -> Column:
"""Filter the LD set by a given R2 threshold.
Args:
ld_set (Column): LD set
r2_threshold (float): R2 threshold to filter the LD set on
Returns:
Column: Filtered LD index
"""
return f.when(
ld_set.isNotNull(),
f.filter(
ld_set,
lambda tag: tag["r2Overall"] >= r2_threshold,
),
)

def find_overlaps(
self: StudyLocus, study_index: StudyIndex, intra_study_overlap: bool = False
) -> StudyLocusOverlap:
Expand Down Expand Up @@ -524,20 +543,24 @@ def annotate_locus_statistics(
return self

def annotate_ld(
self: StudyLocus, study_index: StudyIndex, ld_index: LDIndex
self: StudyLocus,
study_index: StudyIndex,
ld_index: LDIndex,
r2_threshold: float = 0.0,
) -> StudyLocus:
"""Annotate LD information to study-locus.
Args:
study_index (StudyIndex): Study index to resolve ancestries.
ld_index (LDIndex): LD index to resolve LD information.
r2_threshold (float): R2 threshold to filter the LD index. Default is 0.0.
Returns:
StudyLocus: Study locus annotated with ld information from LD index.
"""
from gentropy.method.ld import LDAnnotator

return LDAnnotator.ld_annotate(self, study_index, ld_index)
return LDAnnotator.ld_annotate(self, study_index, ld_index, r2_threshold)

def clump(self: StudyLocus) -> StudyLocus:
"""Perform LD clumping of the studyLocus.
Expand Down
12 changes: 12 additions & 0 deletions src/gentropy/method/ld.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Performing linkage disequilibrium (LD) operations."""

from __future__ import annotations

from typing import TYPE_CHECKING
Expand Down Expand Up @@ -120,6 +121,7 @@ def ld_annotate(
associations: StudyLocus,
studies: StudyIndex,
ld_index: LDIndex,
r2_threshold: float = 0.5,
) -> StudyLocus:
"""Annotate linkage disequilibrium (LD) information to a set of studyLocus.
Expand All @@ -131,10 +133,14 @@ def ld_annotate(
5. Flags associations with variants that are not found in the LD reference
6. Rescues lead variant when no LD information is available but lead variant is available
!!! note
Because the LD index has a pre-set threshold of R2 = 0.5, this is the minimum threshold for the LD information to be included in the ldSet.
Args:
associations (StudyLocus): Dataset to be LD annotated
studies (StudyIndex): Dataset with study information
ld_index (LDIndex): Dataset with LD information for every variant present in LD matrix
r2_threshold (float): R2 threshold to filter the LD set on. Default is 0.5.
Returns:
StudyLocus: including additional column with LD information.
Expand Down Expand Up @@ -175,6 +181,12 @@ def ld_annotate(
),
)
.drop("ldPopulationStructure")
# Filter the LD set by the R2 threshold and set to null if no LD information passes the threshold
.withColumn(
"ldSet",
StudyLocus.filter_ld_set(f.col("ldSet"), r2_threshold),
)
.withColumn("ldSet", f.when(f.size("ldSet") > 0, f.col("ldSet")))
# QC: Flag associations with variants that are not found in the LD reference
.withColumn(
"qualityControls",
Expand Down
21 changes: 15 additions & 6 deletions tests/gentropy/dataset/test_study_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from gentropy.dataset.study_locus import CredibleInterval, StudyLocus
from gentropy.dataset.study_locus_overlap import StudyLocusOverlap
from gentropy.dataset.summary_statistics import SummaryStatistics
from pyspark.sql import Column, SparkSession
from pyspark.sql import Column, Row, SparkSession
from pyspark.sql.types import (
ArrayType,
BooleanType,
Expand All @@ -23,11 +23,6 @@
)


def test_study_locus_creation(mock_study_locus: StudyLocus) -> None:
"""Test study locus creation with mock data."""
assert isinstance(mock_study_locus, StudyLocus)


@pytest.mark.parametrize(
"has_overlap, expected",
[
Expand Down Expand Up @@ -531,3 +526,17 @@ def test_ldannotate(
assert isinstance(
mock_study_locus.annotate_ld(mock_study_index, mock_ld_index), StudyLocus
)


def test_filter_ld_set(spark: SparkSession) -> None:
"""Test filter_ld_set."""
observed_data = [
Row(studyLocusId="sl1", ldSet=[{"tagVariantId": "tag1", "r2Overall": 0.4}])
]
observed_df = spark.createDataFrame(
observed_data, ["studyLocusId", "ldSet"]
).withColumn("ldSet", StudyLocus.filter_ld_set(f.col("ldSet"), 0.5))
expected_tags_in_ld = 0
assert (
observed_df.filter(f.size("ldSet") > 1).count() == expected_tags_in_ld
), "Expected tags in ld set differ from observed."

0 comments on commit 6d93192

Please sign in to comment.