def calculate_beta_ratio(self: StudyLocusOverlap) -> DataFrame:
    """Calculate the average sign of the beta ratio for overlapping signals.

    The ``statistics`` struct of the overlap dataframe is flattened, rows
    where either beta is null are discarded, and the sign of
    ``left_beta / right_beta`` is averaged per pair of overlapping loci.
    Only pairs whose average sign is exactly 1 or -1 are returned, i.e.
    pairs where every remaining row agrees on the direction of effect.

    Returns:
        DataFrame: One row per (leftStudyLocusId, rightStudyLocusId,
        chromosome) with the column ``beta_ratio_sign_avg``, restricted to
        rows where that average is 1 or -1.
    """
    # Flatten the nested struct so its fields become top-level columns.
    # NOTE(review): `left_beta` / `right_beta` are presumably fields of
    # `statistics` -- confirm against the overlap schema.
    expanded_overlaps = self.df.select("*", "statistics.*").drop("statistics")

    # The ratio is only defined when both betas are present.
    both_betas_not_null_overlaps = expanded_overlaps.filter(
        f.col("left_beta").isNotNull() & f.col("right_beta").isNotNull()
    )

    # Sign of the ratio per row (+1 concordant, -1 discordant), averaged over
    # all rows belonging to the same pair of loci.
    # NOTE(review): a zero `left_beta` yields a sign of 0, which pulls the
    # average away from +/-1 and drops the pair below; a zero `right_beta`
    # is a division by zero (null with Spark's non-ANSI defaults) -- confirm
    # this is the intended treatment of zero betas.
    beta_ratio_sign = (
        both_betas_not_null_overlaps.withColumn(
            "beta_ratio_sign",
            f.signum(f.col("left_beta") / f.col("right_beta")),
        )
        .groupBy("leftStudyLocusId", "rightStudyLocusId", "chromosome")
        .agg(f.avg("beta_ratio_sign").alias("beta_ratio_sign_avg"))
    )

    # Keep only pairs with a unanimous direction. `abs` must wrap the column,
    # not the comparison: the previous `f.abs(col != 1)` applied `abs` to a
    # boolean expression instead of testing |avg| == 1.
    return beta_ratio_sign.filter(f.abs(f.col("beta_ratio_sign_avg")) == 1)