From 50d98ed05769e95c31fc1f773851db0a2dbdcf2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 3 Sep 2024 12:28:44 +0100 Subject: [PATCH 01/48] refactor(L2GFeatureMatrix): remove schema validation --- .../assets/schemas/l2g_feature_matrix.json | 155 ------------------ src/gentropy/dataset/l2g_feature_matrix.py | 74 +++++---- src/gentropy/dataset/l2g_prediction.py | 27 ++- src/gentropy/l2g.py | 9 +- src/gentropy/method/l2g/model.py | 2 +- src/gentropy/method/l2g/trainer.py | 4 +- tests/gentropy/conftest.py | 38 +---- tests/gentropy/dataset/test_l2g.py | 55 +++---- 8 files changed, 88 insertions(+), 276 deletions(-) delete mode 100644 src/gentropy/assets/schemas/l2g_feature_matrix.json diff --git a/src/gentropy/assets/schemas/l2g_feature_matrix.json b/src/gentropy/assets/schemas/l2g_feature_matrix.json deleted file mode 100644 index 114936168..000000000 --- a/src/gentropy/assets/schemas/l2g_feature_matrix.json +++ /dev/null @@ -1,155 +0,0 @@ -{ - "fields": [ - { - "metadata": {}, - "name": "studyLocusId", - "nullable": false, - "type": "long" - }, - { - "metadata": {}, - "name": "geneId", - "nullable": false, - "type": "string" - }, - { - "metadata": {}, - "name": "goldStandardSet", - "nullable": true, - "type": "string" - }, - { - "metadata": {}, - "name": "distanceTssMean", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "distanceTssMinimum", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "vepMaximumNeighborhood", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "vepMaximum", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "vepMeanNeighborhood", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "vepMean", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "eqtlColocClppMaximum", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "eqtlColocClppMaximumNeighborhood", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "eqtlColocLlrMaximum", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "eqtlColocLlrMaximumNeighborhood", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "pqtlColocClppMaximum", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "pqtlColocClppMaximumNeighborhood", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "pqtlColocLlrMaximum", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "pqtlColocLlrMaximumNeighborhood", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "sqtlColocClppMaximum", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "sqtlColocClppMaximumNeighborhood", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "sqtlColocLlrMaximum", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "sqtlColocLlrMaximumNeighborhood", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "tuqtlColocClppMaximum", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "tuqtlColocClppMaximumNeighborhood", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "tuqtlColocLlrMaximum", - "nullable": true, - "type": "float" - }, - { - "metadata": {}, - "name": "tuqtlColocLlrMaximumNeighborhood", - "nullable": true, - "type": "float" - } - ], - "type": "struct" -} diff --git 
a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py index 4c611e3da..2098c3cb4 100644 --- a/src/gentropy/dataset/l2g_feature_matrix.py +++ b/src/gentropy/dataset/l2g_feature_matrix.py @@ -2,17 +2,14 @@ from __future__ import annotations -from dataclasses import dataclass, field from functools import reduce from typing import TYPE_CHECKING, Type -from gentropy.common.schemas import parse_spark_schema from gentropy.common.spark_helpers import convert_from_long_to_wide -from gentropy.dataset.dataset import Dataset from gentropy.method.l2g.feature_factory import ColocalisationFactory, StudyLocusFactory if TYPE_CHECKING: - from pyspark.sql.types import StructType + from pyspark.sql import DataFrame from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.study_index import StudyIndex @@ -20,34 +17,42 @@ from gentropy.dataset.v2g import V2G -@dataclass -class L2GFeatureMatrix(Dataset): - """Dataset with features for Locus to Gene prediction. +class L2GFeatureMatrix: + """Dataset with features for Locus to Gene prediction.""" - Attributes: - features_list (list[str] | None): List of features to use. If None, all possible features are used. - fixed_cols (list[str]): Columns that should be kept fixed in the feature matrix, although not considered as features. - mode (str): Mode of the feature matrix. Defaults to "train". Can be either "train" or "predict". - """ - - features_list: list[str] | None = None - fixed_cols: list[str] = field(default_factory=lambda: ["studyLocusId", "geneId"]) - mode: str = "train" - - def __post_init__(self: L2GFeatureMatrix) -> None: + def __init__( + self, + _df: DataFrame, + features_list: list[str] | None = None, + mode: str = "train", + ) -> None: """Post-initialisation to set the features list. If not provided, all columns except the fixed ones are used. + Args: + _df (DataFrame): Feature matrix dataset + features_list (list[str] | None): List of features to use. If None, all possible features are used. + mode (str): Mode of the feature matrix. Defaults to "train". Can be either "train" or "predict". + Raises: ValueError: If the mode is neither 'train' nor 'predict'. """ - if self.mode not in ["train", "predict"]: + if mode not in ["train", "predict"]: raise ValueError("Mode should be either 'train' or 'predict'") - if self.mode == "train": - self.fixed_cols = self.fixed_cols + ["goldStandardSet"] - self.features_list = self.features_list or [ - col for col in self._df.columns if col not in self.fixed_cols + + self.fixed_cols = ["studyLocusId", "geneId"] + if mode == "train": + self.fixed_cols.append("goldStandardSet") + + self.features_list = features_list or [ + col for col in _df.columns if col not in self.fixed_cols ] - self.validate_schema() + self._df = _df.selectExpr( + self.fixed_cols + + [ + f"CAST({feature} AS FLOAT) AS {feature}" + for feature in self.features_list + ] + ) @classmethod def generate_features( @@ -95,19 +100,9 @@ def generate_features( _df=convert_from_long_to_wide( fm, ["studyLocusId", "geneId"], "featureName", "featureValue" ), - _schema=cls.get_schema(), features_list=features_list, ) - @classmethod - def get_schema(cls: type[L2GFeatureMatrix]) -> StructType: - """Provides the schema for the L2gFeatureMatrix dataset. 
- - Returns: - StructType: Schema for the L2gFeatureMatrix dataset - """ - return parse_spark_schema("l2g_feature_matrix.json") - def calculate_feature_missingness_rate( self: L2GFeatureMatrix, ) -> dict[str, float]: @@ -145,7 +140,7 @@ def fill_na( Returns: L2GFeatureMatrix: L2G feature matrix dataset """ - self.df = self._df.fillna(value, subset=subset) + self._df = self._df.fillna(value, subset=subset) return self def select_features( @@ -164,6 +159,13 @@ def select_features( ValueError: If no features have been selected. """ if features_list := features_list or self.features_list: - self.df = self._df.select(self.fixed_cols + features_list) + # cast to float every feature in the features_list + self._df = self._df.selectExpr( + self.fixed_cols + + [ + f"CAST({feature} AS FLOAT) AS {feature}" + for feature in features_list + ] + ) return self raise ValueError("features_list cannot be None") diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 9895f55b7..724ada584 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -83,22 +83,17 @@ def from_credible_set( colocalisation=coloc, ).fill_na() - gwas_fm = ( - L2GFeatureMatrix( - _df=( - fm.df.join( - credible_set.filter_by_study_type( - "gwas", study_index - ).df.select("studyLocusId"), - on="studyLocusId", - ) - ), - _schema=L2GFeatureMatrix.get_schema(), - mode="predict", - ) - .select_features(features_list) - .persist() - ) + gwas_fm = L2GFeatureMatrix( + _df=( + fm._df.join( + credible_set.filter_by_study_type("gwas", study_index).df.select( + "studyLocusId" + ), + on="studyLocusId", + ) + ), + mode="predict", + ).select_features(features_list) return ( l2g_model.predict(gwas_fm, session), gwas_fm, diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 432e46f88..444be4ef5 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -126,7 +126,7 @@ def run_predict(self) -> None: download_from_hub=self.download_from_hub, ) if self.feature_matrix_path: - feature_matrix.df.write.mode(self.session.write_mode).parquet( + feature_matrix._df.write.mode(self.session.write_mode).parquet( self.feature_matrix_path ) predictions.df.write.mode(self.session.write_mode).parquet( @@ -152,7 +152,7 @@ def run_train(self) -> None: wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev") # Process gold standard and L2G features - data = self._generate_feature_matrix().persist() + data = self._generate_feature_matrix() # Instantiate classifier and train model l2g_model = LocusToGeneModel( @@ -173,7 +173,7 @@ def run_train(self) -> None: # we upload the model in the filesystem self.model_path.split("/")[-1], hf_hub_token, - data=trained_model.training_data.df.drop( + data=trained_model.training_data._df.drop( "goldStandardSet", "geneId" ).toPandas(), repo_id=self.hf_hub_repo_id, @@ -227,14 +227,13 @@ def _generate_feature_matrix(self) -> L2GFeatureMatrix: return ( L2GFeatureMatrix( - _df=fm.df.join( + _df=fm._df.join( f.broadcast( gold_standards.df.drop("variantId", "studyId", "sources") ), on=["studyLocusId", "geneId"], how="inner", ), - _schema=L2GFeatureMatrix.get_schema(), ) .fill_na() .select_features(self.features_list) diff --git a/src/gentropy/method/l2g/model.py b/src/gentropy/method/l2g/model.py index e0d9e42fb..6e0b0fda1 100644 --- a/src/gentropy/method/l2g/model.py +++ b/src/gentropy/method/l2g/model.py @@ -114,7 +114,7 @@ def predict( pd_dataframe.iteritems = pd_dataframe.items - feature_matrix_pdf = feature_matrix.df.toPandas() + 
feature_matrix_pdf = feature_matrix._df.toPandas() # L2G score is the probability the classifier assigns to the positive class (the second element in the probability array) feature_matrix_pdf["score"] = self.model.predict_proba( # We drop the fixed columns to only pass the feature values to the classifier diff --git a/src/gentropy/method/l2g/trainer.py b/src/gentropy/method/l2g/trainer.py index 85fedc45b..69dfb24ff 100644 --- a/src/gentropy/method/l2g/trainer.py +++ b/src/gentropy/method/l2g/trainer.py @@ -134,7 +134,7 @@ def log_to_wandb( run.log({"f1": f1_score(self.y_test, y_predicted, average="weighted")}) # Track gold standards and their features run.log( - {"featureMatrix": Table(dataframe=self.feature_matrix.df.toPandas())} + {"featureMatrix": Table(dataframe=self.feature_matrix._df.toPandas())} ) # Log feature missingness run.log( @@ -155,7 +155,7 @@ def train( Returns: LocusToGeneModel: Fitted model """ - data_df = self.feature_matrix.df.drop("geneId").toPandas() + data_df = self.feature_matrix._df.drop("geneId").toPandas() # Encode labels in `goldStandardSet` to a numeric value data_df["goldStandardSet"] = data_df["goldStandardSet"].map( diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py index 9ae7ace58..18425077b 100644 --- a/tests/gentropy/conftest.py +++ b/tests/gentropy/conftest.py @@ -579,37 +579,17 @@ def sample_otp_interactions(spark: SparkSession) -> DataFrame: @pytest.fixture() def mock_l2g_feature_matrix(spark: SparkSession) -> L2GFeatureMatrix: """Mock l2g feature matrix dataset.""" - schema = L2GFeatureMatrix.get_schema() - - data_spec = ( - dg.DataGenerator( - spark, - rows=50, - partitions=4, - randomSeedMethod="hash_fieldname", - ) - .withSchema(schema) - .withColumnSpec("distanceTssMean", percentNulls=0.1) - .withColumnSpec("distanceTssMinimum", percentNulls=0.1) - .withColumnSpec("eqtlColocClppMaximum", percentNulls=0.1) - .withColumnSpec("eqtlColocClppMaximumNeighborhood", percentNulls=0.1) - .withColumnSpec("eqtlColocLlrMaximum", percentNulls=0.1) - .withColumnSpec("eqtlColocLlrMaximumNeighborhood", percentNulls=0.1) - .withColumnSpec("pqtlColocClppMaximum", percentNulls=0.1) - .withColumnSpec("pqtlColocClppMaximumNeighborhood", percentNulls=0.1) - .withColumnSpec("pqtlColocLlrMaximum", percentNulls=0.1) - .withColumnSpec("pqtlColocLlrMaximumNeighborhood", percentNulls=0.1) - .withColumnSpec("sqtlColocClppMaximum", percentNulls=0.1) - .withColumnSpec("sqtlColocClppMaximumNeighborhood", percentNulls=0.1) - .withColumnSpec("sqtlColocLlrMaximum", percentNulls=0.1) - .withColumnSpec("sqtlColocLlrMaximumNeighborhood", percentNulls=0.1) - .withColumnSpec( - "goldStandardSet", percentNulls=0.0, values=["positive", "negative"] - ) + return L2GFeatureMatrix( + _df=spark.createDataFrame( + [ + (1, "gene1", 100.0, None), + (2, "gene2", 1000.0, 0.0), + ], + "studyLocusId LONG, geneId STRING, distanceTssMean FLOAT, distanceTssMinimum FLOAT", + ), + mode="predict", ) - return L2GFeatureMatrix(_df=data_spec.build(), _schema=schema) - @pytest.fixture() def mock_l2g_gold_standard(spark: SparkSession) -> L2GGoldStandard: diff --git a/tests/gentropy/dataset/test_l2g.py b/tests/gentropy/dataset/test_l2g.py index d0f1c3672..496398945 100644 --- a/tests/gentropy/dataset/test_l2g.py +++ b/tests/gentropy/dataset/test_l2g.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -import pytest +from pyspark.sql.types import FloatType from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.l2g_gold_standard import L2GGoldStandard @@ -34,7 +34,7 
@@ def test_process_gene_interactions(sample_otp_interactions: DataFrame) -> None:
     ), "Gene interactions has a different schema."
 
 
-def test_predictions(mock_l2g_predictions: L2GFeatureMatrix) -> None:
+def test_predictions(mock_l2g_predictions: L2GPrediction) -> None:
     """Test L2G predictions creation with mock data."""
     assert isinstance(mock_l2g_predictions, L2GPrediction)
 
@@ -154,45 +154,36 @@ def test_remove_false_negatives(spark: SparkSession) -> None:
     assert observed_df.collect() == expected_df.collect()
 
 
-def test_l2g_feature_constructor_with_schema_mismatch(spark: SparkSession) -> None:
-    """Test if provided shema mismatch results in error in L2GFeatureMatrix constructor.
-
-    distanceTssMean is expected to be FLOAT by schema in src.gentropy.assets.schemas and is actualy DOUBLE.
-    """
-    with pytest.raises(ValueError) as e:
-        L2GFeatureMatrix(
-            _df=spark.createDataFrame(
-                [
-                    (1, "gene1", 100.0),
-                    (2, "gene2", 1000.0),
-                ],
-                "studyLocusId LONG, geneId STRING, distanceTssMean DOUBLE",
-            ),
-            _schema=L2GFeatureMatrix.get_schema(),
-        )
-    assert e.value.args[0] == (
-        "The following fields present differences in their datatypes: ['distanceTssMean']."
-    )
-
-
-def test_calculate_feature_missingness_rate(spark: SparkSession) -> None:
-    """Test L2GFeatureMatrix.calculate_feature_missingness_rate."""
+def test_l2g_feature_constructor_with_schema_mismatch(
+    spark: SparkSession,
+) -> None:
+    """Test if provided schema mismatch is converted to the right type in the L2GFeatureMatrix constructor."""
     fm = L2GFeatureMatrix(
         _df=spark.createDataFrame(
             [
-                (1, "gene1", 100.0, None),
-                (2, "gene2", 1000.0, 0.0),
+                (1, "gene1", 100.0),
+                (2, "gene2", 1000.0),
             ],
-            "studyLocusId LONG, geneId STRING, distanceTssMean FLOAT, distanceTssMinimum FLOAT",
+            "studyLocusId LONG, geneId STRING, distanceTssMean DOUBLE",
         ),
-        _schema=L2GFeatureMatrix.get_schema(),
+        mode="predict",
     )
+    assert (
+        fm._df.schema["distanceTssMean"].dataType == FloatType()
+    ), "Feature `distanceTssMean` is not being cast to FloatType. Check L2GFeatureMatrix constructor."
+
 
+def test_calculate_feature_missingness_rate(
+    spark: SparkSession, mock_l2g_feature_matrix: L2GFeatureMatrix
+) -> None:
+    """Test L2GFeatureMatrix.calculate_feature_missingness_rate."""
     expected_missingness = {"distanceTssMean": 0.0, "distanceTssMinimum": 1.0}
-    observed_missingness = fm.calculate_feature_missingness_rate()
+    observed_missingness = mock_l2g_feature_matrix.calculate_feature_missingness_rate()
     assert isinstance(observed_missingness, dict)
-    assert fm.features_list is not None and len(observed_missingness) == len(
-        fm.features_list
+    assert mock_l2g_feature_matrix.features_list is not None and len(
+        observed_missingness
+    ) == len(
+        mock_l2g_feature_matrix.features_list
     ), "Missing features in the missingness rate dictionary."
assert (
        observed_missingness == expected_missingness

From e1f7c5ca21351837877129d7a8ce16dd1b2e25a4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Tue, 3 Sep 2024 18:48:06 +0100
Subject: [PATCH 02/48] refactor(FeatureFactory): reshape feature generation
 WIP

---
 config/step/ot_locus_to_gene_train.yaml    |   7 ++
 src/gentropy/config.py                     |  49 +---------
 src/gentropy/dataset/l2g_feature.py        |  16 +++-
 src/gentropy/dataset/l2g_feature_matrix.py |  61 +++++--------
 src/gentropy/l2g.py                        |  14 ++-
 src/gentropy/method/l2g/feature_factory.py | 100 ++++++++++++++++++++-
 6 files changed, 148 insertions(+), 99 deletions(-)

diff --git a/config/step/ot_locus_to_gene_train.yaml b/config/step/ot_locus_to_gene_train.yaml
index b59a24dae..c6d1d7ffd 100644
--- a/config/step/ot_locus_to_gene_train.yaml
+++ b/config/step/ot_locus_to_gene_train.yaml
@@ -17,3 +17,10 @@ hyperparameters:
   max_depth: 5
   loss: log_loss
 download_from_hub: true
+features_list:
+- name: distanceTssMean
+  # average distance of all tagging variants to gene TSS
+  path: ${datasets.variant_to_gene}
+- name: distanceTssMinimum
+  # minimum distance of all tagging variants to gene TSS
+  path: ${datasets.variant_to_gene}
diff --git a/src/gentropy/config.py b/src/gentropy/config.py
index ed5fe4c81..eb09c7b2b 100644
--- a/src/gentropy/config.py
+++ b/src/gentropy/config.py
@@ -226,54 +226,7 @@ class LocusToGeneConfig(StepConfig):
     feature_matrix_path: str | None = None
     gold_standard_curation_path: str | None = None
     gene_interactions_path: str | None = None
-    features_list: list[str] = field(
-        default_factory=lambda: [
-            # average distance of all tagging variants to gene TSS
-            "distanceTssMean",
-            # minimum distance of all tagging variants to gene TSS
-            "distanceTssMinimum",
-            # maximum vep consequence score of the locus 95% credible set among all genes in the vicinity
-            "vepMaximumNeighborhood",
-            # maximum vep consequence score of the locus 95% credible set split by gene
-            "vepMaximum",
-            # mean vep consequence score of the locus 95% credible set among all genes in the vicinity
-            "vepMeanNeighborhood",
-            # mean vep consequence score of the locus 95% credible set split by gene
-            "vepMean",
-            # max clpp for each (study, locus, gene) aggregating over all eQTLs
-            "eqtlColocClppMaximum",
-            # max clpp for each (study, locus) aggregating over all eQTLs
-            "eqtlColocClppMaximumNeighborhood",
-            # max clpp for each (study, locus, gene) aggregating over all pQTLs
-            "pqtlColocClppMaximum",
-            # max clpp for each (study, locus) aggregating over all pQTLs
-            "pqtlColocClppMaximumNeighborhood",
-            # max clpp for each (study, locus, gene) aggregating over all sQTLs
-            "sqtlColocClppMaximum",
-            # max clpp for each (study, locus) aggregating over all sQTLs
-            "sqtlColocClppMaximumNeighborhood",
-            # max clpp for each (study, locus) aggregating over all tuQTLs
-            "tuqtlColocClppMaximum",
-            # max clpp for each (study, locus, gene) aggregating over all tuQTLs
-            "tuqtlColocClppMaximumNeighborhood",
-            # max log-likelihood ratio value for each (study, locus, gene) aggregating over all eQTLs
-            "eqtlColocLlrMaximum",
-            # max log-likelihood ratio value for each (study, locus) aggregating over all eQTLs
-            "eqtlColocLlrMaximumNeighborhood",
-            # max log-likelihood ratio value for each (study, locus, gene) aggregating over all pQTLs
-            "pqtlColocLlrMaximum",
-            # max log-likelihood ratio value for each (study, locus) aggregating over all pQTLs
-            "pqtlColocLlrMaximumNeighborhood",
-            # max log-likelihood ratio value for each (study, locus, gene) aggregating over all sQTLs
-
"sqtlColocLlrMaximum", - # max log-likelihood ratio value for each (study, locus) aggregating over all sQTLs - "sqtlColocLlrMaximumNeighborhood", - # max log-likelihood ratio value for each (study, locus, gene) aggregating over all tuQTLs - "tuqtlColocLlrMaximum", - # max log-likelihood ratio value for each (study, locus) aggregating over all tuQTLs - "tuqtlColocLlrMaximumNeighborhood", - ] - ) + features_list: list[dict[str, str]] = MISSING hyperparameters: dict[str, Any] = field( default_factory=lambda: { "n_estimators": 100, diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py index 2e9f19d61..0a78e222f 100644 --- a/src/gentropy/dataset/l2g_feature.py +++ b/src/gentropy/dataset/l2g_feature.py @@ -1,8 +1,10 @@ """L2G Feature Dataset.""" + from __future__ import annotations +from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from gentropy.common.schemas import parse_spark_schema from gentropy.dataset.dataset import Dataset @@ -12,7 +14,7 @@ @dataclass -class L2GFeature(Dataset): +class L2GFeature(Dataset, ABC): """Locus-to-gene feature dataset.""" @classmethod @@ -23,3 +25,13 @@ def get_schema(cls: type[L2GFeature]) -> StructType: StructType: Schema for the L2GFeature dataset """ return parse_spark_schema("l2g_feature.json") + + @classmethod + @abstractmethod + def compute(cls: type[L2GFeature]) -> L2GFeature: + """Computes the L2GFeature dataset. + + Returns: + L2GFeature: a L2GFeature dataset + """ + pass diff --git a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py index 2098c3cb4..c063e02dc 100644 --- a/src/gentropy/dataset/l2g_feature_matrix.py +++ b/src/gentropy/dataset/l2g_feature_matrix.py @@ -6,15 +6,12 @@ from typing import TYPE_CHECKING, Type from gentropy.common.spark_helpers import convert_from_long_to_wide -from gentropy.method.l2g.feature_factory import ColocalisationFactory, StudyLocusFactory +from gentropy.method.l2g.feature_factory import FeatureFactory if TYPE_CHECKING: from pyspark.sql import DataFrame - from gentropy.dataset.colocalisation import Colocalisation - from gentropy.dataset.study_index import StudyIndex - from gentropy.dataset.study_locus import StudyLocus - from gentropy.dataset.v2g import V2G + from gentropy.common.session import Session class L2GFeatureMatrix: @@ -55,52 +52,35 @@ def __init__( ) @classmethod - def generate_features( + def from_features_list( cls: Type[L2GFeatureMatrix], - features_list: list[str], - credible_set: StudyLocus, - study_index: StudyIndex, - variant_gene: V2G, - colocalisation: Colocalisation, + session: Session, + features_list: list[dict[str, str]], ) -> L2GFeatureMatrix: - """Generate features from the gentropy datasets. + """Generate features from the gentropy datasets by calling the feature factory that will instantiate the corresponding features. Args: - features_list (list[str]): List of features to generate - credible_set (StudyLocus): Credible set dataset - study_index (StudyIndex): Study index dataset - variant_gene (V2G): Variant to gene dataset - colocalisation (Colocalisation): Colocalisation dataset + session (Session): Session object + features_list (list[dict[str, str]]): List of objects with 2 keys corresponding to the features to generate: 'name' and 'path'. 
         Returns:
             L2GFeatureMatrix: L2G feature matrix dataset
-
-        Raises:
-            ValueError: If the feature matrix is empty
         """
-        if features_dfs := [
-            # Extract features
-            ColocalisationFactory._get_max_coloc_per_credible_set(
-                colocalisation,
-                credible_set,
-                study_index,
-            ).df,
-            StudyLocusFactory._get_tss_distance_features(credible_set, variant_gene).df,
-            StudyLocusFactory._get_vep_features(credible_set, variant_gene).df,
-        ]:
-            fm = reduce(
-                lambda x, y: x.unionByName(y),
-                features_dfs,
-            )
-        else:
-            raise ValueError("No features found")
-
-        # raise error if the feature matrix is empty
+        features_long_df = reduce(
+            lambda x, y: x.unionByName(y, allowMissingColumns=True),
+            [
+                # Compute all features and merge them into a single dataframe
+                feature.df
+                for feature in FeatureFactory.generate_features(session, features_list)
+            ],
+        )
         return cls(
             _df=convert_from_long_to_wide(
-                fm, ["studyLocusId", "geneId"], "featureName", "featureValue"
+                features_long_df,
+                ["studyLocusId", "geneId"],
+                "featureName",
+                "featureValue",
             ),
-            features_list=features_list,
         )

     def calculate_feature_missingness_rate(
             )
             return self
         raise ValueError("features_list cannot be None")
diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py
index 444be4ef5..3ddacb81c 100644
--- a/src/gentropy/l2g.py
+++ b/src/gentropy/l2g.py
@@ -36,7 +36,7 @@ def __init__(
         study_index_path: str,
         gold_standard_curation_path: str,
         gene_interactions_path: str,
-        features_list: list[str],
+        features_list: list[dict[str, str]],
         hyperparameters: dict[str, Any],
         download_from_hub: bool,
         model_path: str | None,
@@ -56,7 +56,7 @@ def __init__(
             study_index_path (str): Path to the study index dataset
             gold_standard_curation_path (str): Path to the gold standard curation dataset
             gene_interactions_path (str): Path to the gene interactions dataset
-            features_list (list[str]): List of features to use for the model
+            features_list (list[dict[str, str]]): List of features to use for the model. It is a list of objects with 2 keys: 'name' and 'path'.
             hyperparameters (dict[str, Any]): Hyperparameters for the model
             download_from_hub (bool): Whether to download the model from the Hugging Face Hub
             model_path (str | None): Path to the fitted model
@@ -217,12 +217,10 @@ def _generate_feature_matrix(self) -> L2GFeatureMatrix:
             interactions=interactions,
         )

-        fm = L2GFeatureMatrix.generate_features(
-            features_list=self.features_list,
-            credible_set=self.credible_set,
-            study_index=self.studies,
-            variant_gene=self.v2g,
-            colocalisation=self.coloc,
+        # TODO: Should StudyLocus and GoldStandard have an `annotate_w_features` method?
+ fm = L2GFeatureMatrix.from_features_list( + self.session, + self.features_list, ) return ( diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 1158c6067..b31ff5c18 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -4,10 +4,11 @@ from functools import reduce from itertools import chain -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pyspark.sql.functions as f +from gentropy.common.session import Session from gentropy.common.spark_helpers import ( convert_from_wide_to_long, get_record_with_maximum_value, @@ -339,3 +340,100 @@ def _aggregate_vep_feature( ).filter(f.col("featureValue").isNotNull()), _schema=L2GFeature.get_schema(), ) + + +class DistanceTssMinimumFeature(L2GFeature): + """Minimum distance of all tagging variants to gene TSS.""" + + @classmethod + def compute( + cls: type[DistanceTssMinimumFeature], input_dependency: V2G + ) -> L2GFeature: + """Computes the feature. + + Args: + input_dependency (V2G): V2G dependency + + Returns: + L2GFeature: Feature dataset + + Raises: + NotImplementedError: Not implemented + """ + raise NotImplementedError + + +class DistanceTssMeanFeature(L2GFeature): + """Average distance of all tagging variants to gene TSS.""" + + @classmethod + def compute(cls: type[DistanceTssMeanFeature], input_dependency: V2G) -> L2GFeature: + """Computes the feature. + + Args: + input_dependency (V2G): V2G dependency + Returns: + L2GFeature: Feature dataset + Raises: + NotImplementedError: Not implemented + """ + raise NotImplementedError + + +class FeatureFactory: + """Factory class for creating features.""" + + # TODO: should this be live in the `features_list`? + feature_mapper = { + "distanceTssMinimum": DistanceTssMinimumFeature, + "distanceTssMean": DistanceTssMeanFeature, + } + + @classmethod + def generate_features( + cls: type[FeatureFactory], session: Session, features_list: list[dict[str, str]] + ) -> list[L2GFeature]: + """Generates a feature matrix by reading an object with instructions on how to create the features. + + Args: + session (Session): session object + features_list (list[dict[str, str]]): list of objects with 2 keys: 'name' and 'path'. + + Returns: + list[L2GFeature]: list of computed features. + """ + computed_features = [] + for feature in features_list: + input_dependency = cls.inject_dependency(session, feature["path"]) + computed_features.append( + cls.compute_feature(feature["name"], input_dependency) + ) + return computed_features + + @classmethod + def compute_feature( + cls: type[FeatureFactory], feature_name: str, input_dependency: Any + ) -> L2GFeature: + """Instantiates feature class. + + Args: + feature_name (str): name of the feature + input_dependency (Any): dependency object + Returns: + L2GFeature: instantiated feature object + """ + return cls.feature_mapper[feature_name].compute(input_dependency) + + @classmethod + def inject_dependency( + cls: type[FeatureFactory], session: Session, feature_dependency_path: str + ) -> Any: + """Injects a dependency into the feature factory. 
+
+        Args:
+            session (Session): session object
+            feature_dependency_path (str): path to the dependency of the feature
+        Returns:
+            Any: dependency object
+        """
+        return V2G.from_parquet(session, feature_dependency_path)

From a7757ac4b0fdd01d29a49aac0339e0a9ff9f9835 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]"
 <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 3 Sep 2024 17:48:39 +0000
Subject: [PATCH 03/48] chore: pre-commit auto fixes [...]

---
 src/gentropy/dataset/l2g_feature.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py
index 0a78e222f..366bbb88b 100644
--- a/src/gentropy/dataset/l2g_feature.py
+++ b/src/gentropy/dataset/l2g_feature.py
@@ -4,7 +4,7 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING

 from gentropy.common.schemas import parse_spark_schema
 from gentropy.dataset.dataset import Dataset

From 8a70bf2c0d318da18e7aad5212f0d15a49e31485 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Fri, 6 Sep 2024 12:01:17 +0100
Subject: [PATCH 04/48] chore: set l2gfeature properties with decorator

---
 src/gentropy/dataset/l2g_feature.py        | 14 ++++
 src/gentropy/method/l2g/feature_factory.py | 86 +++++++++++++++++-----
 2 files changed, 83 insertions(+), 17 deletions(-)

diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py
index 0a78e222f..4d47873f4 100644
--- a/src/gentropy/dataset/l2g_feature.py
+++ b/src/gentropy/dataset/l2g_feature.py
@@ -12,11 +12,25 @@

 if TYPE_CHECKING:
     from pyspark.sql.types import StructType

+    from gentropy.dataset.study_locus import StudyLocus
+

 @dataclass
 class L2GFeature(Dataset, ABC):
     """Locus-to-gene feature dataset."""

+    input_dependency: Any = None
+
+    @property
+    def input_dependency(self: L2GFeature) -> Any:
+        """Getter for the input_dependency."""
+        return self._input_dependency
+
+    @input_dependency.setter
+    def input_dependency(self: L2GFeature, value: Any) -> None:
+        """Setter for the input_dependency."""
+        self._input_dependency = value
+
     @classmethod
     def get_schema(cls: type[L2GFeature]) -> StructType:
         """Provides the schema for the L2GFeature dataset.
diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py
index b31ff5c18..67cd17ede 100644
--- a/src/gentropy/method/l2g/feature_factory.py
+++ b/src/gentropy/method/l2g/feature_factory.py
@@ -4,7 +4,7 @@

 from functools import reduce
 from itertools import chain
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Mapping

 import pyspark.sql.functions as f

@@ -366,32 +366,79 @@ def compute(
 class DistanceTssMeanFeature(L2GFeature):
     """Average distance of all tagging variants to gene TSS."""

+    # TODO: credible_set should be a property??
+
+    @classmethod
+    def dummy(
+        cls: type[DistanceTssMeanFeature],
+        _input_dependency: Any,
+    ):
+        cls._input_dependency = _input_dependency
+
     @classmethod
-    def compute(cls: type[DistanceTssMeanFeature], input_dependency: V2G) -> L2GFeature:
+    def compute(
+        cls: type[DistanceTssMeanFeature],
+        input_dependency: Any,
+        credible_set: StudyLocus,
+    ) -> Any:
         """Computes the feature.
-        Args:
-            input_dependency (V2G): V2G dependency
         Returns:
             L2GFeature: Feature dataset
-        Raises:
-            NotImplementedError: Not implemented
         """
-        raise NotImplementedError
+        agg_expr = f.mean("weightedScore").alias("distanceTssMean")
+        # Start of common logic
+        v2g = input_dependency.df.filter(f.col("datasourceId") == "canonical_tss")
+        wide_df = (
+            credible_set.df.withColumn("variantInLocus", f.explode_outer("locus"))
+            .select(
+                "studyLocusId", "variantInLocusId", "variantInLocusPosteriorProbability"
+            )
+            .join(
+                v2g.selectExpr("variantId as variantInLocusId", "geneId", "score"),
+                on="variantInLocusId",
+                how="inner",
+            )
+            .withColumn(
+                "weightedScore",
+                f.col("score") * f.col("variantInLocusPosteriorProbability"),
+            )
+            .groupBy("studyLocusId", "geneId")
+            .agg(agg_expr)
+        )
+        return DistanceTssMeanFeature(
+            _df=convert_from_wide_to_long(
+                wide_df,
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=L2GFeature.get_schema(),
+        )


 class FeatureFactory:
     """Factory class for creating features."""

-    # TODO: should this be live in the `features_list`?
-    feature_mapper = {
-        "distanceTssMinimum": DistanceTssMinimumFeature,
+    # TODO: should this live in the `features_list`?
+    feature_mapper: Mapping[str, type[L2GFeature]] = {
+        # "distanceTssMinimum": DistanceTssMinimumFeature,
         "distanceTssMean": DistanceTssMeanFeature,
     }

+    def __init__(self: FeatureFactory, credible_set: StudyLocus) -> None:
+        """Initializes the factory.
+
+        Args:
+            credible_set (StudyLocus): credible sets to annotate
+        """
+        self.credible_set = credible_set
+
     @classmethod
     def generate_features(
-        cls: type[FeatureFactory], session: Session, features_list: list[dict[str, str]]
+        cls: type[FeatureFactory],
+        session: Session,
+        features_list: list[dict[str, str]],
     ) -> list[L2GFeature]:
         """Generates a feature matrix by reading an object with instructions on how to create the features.

         Args:
             session (Session): session object
             features_list (list[dict[str, str]]): list of objects with 2 keys: 'name' and 'path'.

         Returns:
             list[L2GFeature]: list of computed features.
         """
         computed_features = []
         for feature in features_list:
-            input_dependency = cls.inject_dependency(session, feature["path"])
-            computed_features.append(
-                cls.compute_feature(feature["name"], input_dependency)
-            )
+            if feature["name"] in cls.feature_mapper:
+                input_dependency = cls.inject_dependency(session, feature["path"])
+                computed_features.append(
+                    cls.compute_feature(feature["name"], input_dependency)
+                )
+            else:
+                raise ValueError(f"Feature {feature['name']} not found.")
         return computed_features

     @classmethod
     def compute_feature(
         cls: type[FeatureFactory], feature_name: str, input_dependency: Any
     ) -> L2GFeature:
         """Instantiates feature class.

         Args:
             feature_name (str): name of the feature
             input_dependency (Any): dependency object
         Returns:
             L2GFeature: instantiated feature object
         """
-        return cls.feature_mapper[feature_name].compute(input_dependency)
+        feature_cls = cls.feature_mapper[feature_name]
+        return feature_cls.compute(input_dependency)

     @classmethod
     def inject_dependency(
         cls: type[FeatureFactory], session: Session, feature_dependency_path: str
     ) -> Any:
         """Injects a dependency into the feature factory.

         Args:
             session (Session): session object
             feature_dependency_path (str): path to the dependency of the feature
         Returns:
             Any: dependency object
         """
+        # TODO: Dependency injection feature responsibility?
return V2G.from_parquet(session, feature_dependency_path) From c690ffcc3608d63d3ba89c04f572d50fcb6f5b3b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 6 Sep 2024 14:16:37 +0100 Subject: [PATCH 05/48] chore(l2gfeature): make credible_set and input_dependency instance attributes --- src/gentropy/dataset/l2g_feature.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py index 4d47873f4..cbcd2bec5 100644 --- a/src/gentropy/dataset/l2g_feature.py +++ b/src/gentropy/dataset/l2g_feature.py @@ -19,17 +19,20 @@ class L2GFeature(Dataset, ABC): """Locus-to-gene feature dataset.""" - input_dependency: Any = None - - @property - def input_dependency(self: L2GFeature) -> Any: - """Getter for the input_dependency.""" - return self._input_dependency - - @input_dependency.setter - def set_input_dependency(self: L2GFeature, value: Any) -> None: - """Setter for the input_dependency.""" - self._input_dependency = value + def __init__( + self: L2GFeature, + input_dependency: Any = None, + credible_set: StudyLocus | None = None, + ) -> None: + """Initializes a L2GFeature dataset. + + Args: + input_dependency (Any): The dependency that the L2GFeature dataset depends on. Defaults to None. + credible_set (StudyLocus | None): The credible set that the L2GFeature dataset is based on. Defaults to None. + """ + super().__init__() + self.input_dependency = input_dependency + self.credible_set = credible_set @classmethod def get_schema(cls: type[L2GFeature]) -> StructType: From a54e694127cd0b831cfd4c4f71543350cfea43bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 6 Sep 2024 14:16:59 +0100 Subject: [PATCH 06/48] chore(l2gfeature): make credible_set and input_dependency instance attributes --- src/gentropy/dataset/l2g_feature.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py index cbcd2bec5..6a426b169 100644 --- a/src/gentropy/dataset/l2g_feature.py +++ b/src/gentropy/dataset/l2g_feature.py @@ -19,7 +19,7 @@ class L2GFeature(Dataset, ABC): """Locus-to-gene feature dataset.""" - def __init__( + def __post_init__( self: L2GFeature, input_dependency: Any = None, credible_set: StudyLocus | None = None, @@ -30,7 +30,7 @@ def __init__( input_dependency (Any): The dependency that the L2GFeature dataset depends on. Defaults to None. credible_set (StudyLocus | None): The credible set that the L2GFeature dataset is based on. Defaults to None. """ - super().__init__() + super().__post_init__() self.input_dependency = input_dependency self.credible_set = credible_set From 85a7bf40f86eee2c834d762620845da90af0c4f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 6 Sep 2024 14:17:38 +0100 Subject: [PATCH 07/48] chore(featurefactory): distanceTssMeanFeature working --- src/gentropy/method/l2g/feature_factory.py | 23 ++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 67cd17ede..ecc89a9f2 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -366,19 +366,22 @@ def compute( class DistanceTssMeanFeature(L2GFeature): """Average distance of all tagging variants to gene TSS.""" - # TODO: credible_set should be a property?? 
+    fill_na_value = 500_000

     @classmethod
     def dummy(
         cls: type[DistanceTssMeanFeature],
         input_dependency: Any,
         credible_set: StudyLocus,
     ):
-        cls._input_dependency = _input_dependency
+        cls.input_dependency = input_dependency
+        cls.credible_set = credible_set
+        return cls

     @classmethod
     def compute(
         cls: type[DistanceTssMeanFeature],
-        input_dependency: Any,
+        input_dependency: V2G,
         credible_set: StudyLocus,
     ) -> Any:
         """Computes the feature.

         Returns:
             L2GFeature: Feature dataset
         """
         agg_expr = f.mean("weightedScore").alias("distanceTssMean")
-        # Start of common logic
+        # Everything but expression is common logic
         v2g = input_dependency.df.filter(f.col("datasourceId") == "canonical_tss")
         wide_df = (
             credible_set.df.withColumn("variantInLocus", f.explode_outer("locus"))
             .select(
-                "studyLocusId", "variantInLocusId", "variantInLocusPosteriorProbability"
+                "studyLocusId",
+                f.col("variantInLocus.variantId").alias("variantInLocusId"),
+                f.col("variantInLocus.posteriorProbability").alias(
+                    "variantInLocusPosteriorProbability"
+                ),
             )
             .join(
                 v2g.selectExpr("variantId as variantInLocusId", "geneId", "score"),
                 on="variantInLocusId",
                 how="inner",
             )
             .withColumn(
                 "weightedScore",
                 f.col("score") * f.col("variantInLocusPosteriorProbability"),
             )
             .groupBy("studyLocusId", "geneId")
             .agg(agg_expr)
         )
-        return DistanceTssMeanFeature(
+        return cls(
             _df=convert_from_wide_to_long(
                 wide_df,
                 id_vars=("studyLocusId", "geneId"),
                 var_name="featureName",
                 value_name="featureValue",
             ),
-            _schema=L2GFeature.get_schema(),
+            _schema=cls.get_schema(),
         )

From d24de6d806bb3127a7280da749382cb8a9489cff Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Mon, 9 Sep 2024 09:27:53 +0200
Subject: [PATCH 08/48] refactor(l2g): improve step dependency management

---
 src/gentropy/l2g.py | 99 +++++++++++++++++++++++----------------------
 1 file changed, 51 insertions(+), 48 deletions(-)

diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py
index 3ddacb81c..5d5337db3 100644
--- a/src/gentropy/l2g.py
+++ b/src/gentropy/l2g.py
@@ -28,41 +28,42 @@ class LocusToGeneStep:
     def __init__(
         self,
         session: Session,
-        run_mode: str,
-        predictions_path: str,
-        credible_set_path: str,
-        variant_gene_path: str,
-        colocalisation_path: str,
-        study_index_path: str,
-        gold_standard_curation_path: str,
-        gene_interactions_path: str,
-        features_list: list[dict[str, str]],
         hyperparameters: dict[str, Any],
+        *,
+        run_mode: str,
+        features_list: list[str],
         download_from_hub: bool,
+        wandb_run_name: str,
+        model_path: str | None = None,
+        credible_set_path: str,
+        gold_standard_curation_path: str | None = None,
+        variant_gene_path: str | None = None,
+        colocalisation_path: str | None = None,
+        study_index_path: str | None = None,
+        gene_interactions_path: str | None = None,
+        predictions_path: str | None = None,
-        model_path: str | None,
         feature_matrix_path: str | None = None,
-        wandb_run_name: str | None = None,
         hf_hub_repo_id: str | None = LocusToGeneConfig().hf_hub_repo_id,
     ) -> None:
         """Initialise the step and run the logic based on mode.
Args: session (Session): Session object that contains the Spark session - run_mode (str): Run mode, either 'train' or 'predict' - predictions_path (str): Path to save the predictions - credible_set_path (str): Path to the credible set dataset - variant_gene_path (str): Path to the variant to gene dataset - colocalisation_path (str): Path to the colocalisation dataset - study_index_path (str): Path to the study index dataset - gold_standard_curation_path (str): Path to the gold standard curation dataset - gene_interactions_path (str): Path to the gene interactions dataset - features_list (list[dict[str, str]]): List of features to use for the model. It is a list of objects with 2 keys: 'name' and 'path'. hyperparameters (dict[str, Any]): Hyperparameters for the model - download_from_hub (bool): Whether to download the model from the Hugging Face Hub - model_path (str | None): Path to the fitted model - feature_matrix_path (str | None): Path to save the feature matrix. Defaults to None. - wandb_run_name (str | None): Name of the wandb run. Defaults to None. - hf_hub_repo_id (str | None): Hugging Face Hub repo id. Defaults to the one set in the step configuration. + run_mode (str): Run mode, either 'train' or 'predict' + features_list (list[str]): List of features to use for the model + download_from_hub (bool): Whether to download the model from Hugging Face Hub + wandb_run_name (str): Name of the run to track model training in Weights and Biases + model_path (str | None): Path to the model. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name). + credible_set_path (str): Path to the credible set dataset necessary to build the feature matrix + gold_standard_curation_path (str | None): Path to the gold standard curation file + variant_gene_path (str | None): Path to the variant-gene dataset + colocalisation_path (str | None): Path to the colocalisation dataset + study_index_path (str | None): Path to the study index dataset + gene_interactions_path (str | None): Path to the gene interactions dataset + predictions_path (str | None): Path to the L2G predictions output dataset + feature_matrix_path (str | None): Path to the L2G feature matrix output dataset + hf_hub_repo_id (str | None): Hugging Face Hub repository ID. If provided, the model will be uploaded to Hugging Face. 
        Raises:
            ValueError: If run_mode is not 'train' or 'predict'
@@ -76,12 +77,6 @@
         self.run_mode = run_mode
         self.model_path = model_path
         self.predictions_path = predictions_path
-        self.credible_set_path = credible_set_path
-        self.variant_gene_path = variant_gene_path
-        self.colocalisation_path = colocalisation_path
-        self.study_index_path = study_index_path
-        self.gold_standard_curation_path = gold_standard_curation_path
-        self.gene_interactions_path = gene_interactions_path
         self.features_list = list(features_list)
         self.hyperparameters = dict(hyperparameters)
         self.feature_matrix_path = feature_matrix_path
@@ -93,17 +88,31 @@
         self.credible_set = StudyLocus.from_parquet(
             session, credible_set_path, recursiveFileLookup=True
         )
-        self.studies = StudyIndex.from_parquet(
-            session, study_index_path, recursiveFileLookup=True
+        self.studies = (
+            StudyIndex.from_parquet(session, study_index_path, recursiveFileLookup=True)
+            if study_index_path
+            else None
         )
-        self.v2g = V2G.from_parquet(session, variant_gene_path)
-        self.coloc = Colocalisation.from_parquet(
-            session, colocalisation_path, recursiveFileLookup=True
+        self.v2g = (
+            V2G.from_parquet(session, variant_gene_path) if variant_gene_path else None
+        )
+        self.coloc = (
+            Colocalisation.from_parquet(
+                session, colocalisation_path, recursiveFileLookup=True
+            )
+            if colocalisation_path
+            else None
         )

         if run_mode == "predict":
+            if not (self.studies and self.v2g and self.coloc):
+                raise ValueError("Dependencies for predict mode not set.")
             self.run_predict()
         elif run_mode == "train":
+            if not (gold_standard_curation_path and gene_interactions_path):
+                raise ValueError("Dependencies for train mode not set.")
+            self.gs_curation = self.session.spark.read.json(gold_standard_curation_path)
+            self.interactions = self.session.spark.read.parquet(gene_interactions_path)
             self.run_train()

     def run_predict(self) -> None:
@@ -114,6 +123,7 @@
         """
         if not self.predictions_path:
             raise ValueError("predictions_path must be set for predict mode.")
+        # TODO: IMPROVE - it is not correct that L2GPrediction outputs a feature matrix - FM should be written when training
         predictions, feature_matrix = L2GPrediction.from_credible_set(
             self.features_list,
             self.credible_set,
@@ -140,14 +150,9 @@
         Raises:
             ValueError: If gold_standard_curation_path, gene_interactions_path, or wandb_run_name are not set.
         """
-        if not (
-            self.gold_standard_curation_path
-            and self.gene_interactions_path
-            and self.wandb_run_name
-            and self.model_path
-        ):
+        if not (self.wandb_run_name and self.model_path):
             raise ValueError(
-                "gold_standard_curation_path, gene_interactions_path, and wandb_run_name, and a path to save the model must be set for train mode."
+                "wandb_run_name and a path to save the model must be set for train mode."
             )
         wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev")

         # Process gold standard and L2G features
-        data = self._generate_feature_matrix().persist()
+        data = self._generate_feature_matrix()

         # Instantiate classifier and train model
         l2g_model = LocusToGeneModel(
@@ -186,12 +191,10 @@ def _generate_feature_matrix(self) -> L2GFeatureMatrix:

         Returns:
             L2GFeatureMatrix: Feature matrix with gold standards annotated with features.
""" - gs_curation = self.session.spark.read.json(self.gold_standard_curation_path) - interactions = self.session.spark.read.parquet(self.gene_interactions_path) study_locus_overlap = StudyLocus( _df=self.credible_set.df.join( f.broadcast( - gs_curation.select( + self.gs_curation.select( StudyLocus.assign_study_locus_id( f.col("association_info.otg_id"), # studyId f.concat_ws( # variantId @@ -211,10 +214,10 @@ def _generate_feature_matrix(self) -> L2GFeatureMatrix: ).find_overlaps(self.studies) gold_standards = L2GGoldStandard.from_otg_curation( - gold_standard_curation=gs_curation, + gold_standard_curation=self.gs_curation, v2g=self.v2g, study_locus_overlap=study_locus_overlap, - interactions=interactions, + interactions=self.interactions, ) # TODO: Should StudyLocus and GoldStandard have an `annotate_w_features` method? From 6a3af694e99416234105766e9ec54c74a5d6ff81 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Mon, 9 Sep 2024 11:07:25 +0200 Subject: [PATCH 09/48] feat: implement --- src/gentropy/dataset/l2g_feature.py | 13 +- src/gentropy/dataset/l2g_feature_matrix.py | 12 +- src/gentropy/l2g.py | 11 +- src/gentropy/method/l2g/feature_factory.py | 405 +++------------------ 4 files changed, 69 insertions(+), 372 deletions(-) diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py index 6a426b169..c46d2cb0f 100644 --- a/src/gentropy/dataset/l2g_feature.py +++ b/src/gentropy/dataset/l2g_feature.py @@ -21,17 +21,17 @@ class L2GFeature(Dataset, ABC): def __post_init__( self: L2GFeature, - input_dependency: Any = None, + feature_dependency: Any = None, credible_set: StudyLocus | None = None, ) -> None: """Initializes a L2GFeature dataset. Args: - input_dependency (Any): The dependency that the L2GFeature dataset depends on. Defaults to None. + feature_dependency (Any): The dependency that the L2GFeature dataset depends on. Defaults to None. credible_set (StudyLocus | None): The credible set that the L2GFeature dataset is based on. Defaults to None. """ super().__post_init__() - self.input_dependency = input_dependency + self.feature_dependency = feature_dependency self.credible_set = credible_set @classmethod @@ -45,9 +45,14 @@ def get_schema(cls: type[L2GFeature]) -> StructType: @classmethod @abstractmethod - def compute(cls: type[L2GFeature]) -> L2GFeature: + def compute( + cls: type[L2GFeature], credible_set: StudyLocus, feature_dependency: Any + ) -> L2GFeature: """Computes the L2GFeature dataset. 
+        Args:
+            credible_set (StudyLocus): The credible set that will be used for annotation
+            feature_dependency (Any): The dependency that the L2GFeature class needs to compute the feature

         Returns:
             L2GFeature: a L2GFeature dataset
         """
         pass
diff --git a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py
index c063e02dc..5640f513c 100644
--- a/src/gentropy/dataset/l2g_feature_matrix.py
+++ b/src/gentropy/dataset/l2g_feature_matrix.py
@@ -6,7 +6,7 @@

 from typing import TYPE_CHECKING, Type

 from gentropy.common.spark_helpers import convert_from_long_to_wide
-from gentropy.method.l2g.feature_factory import FeatureFactory
+from gentropy.method.l2g.feature_factory import FeatureFactory, L2GFeatureInputLoader

 if TYPE_CHECKING:
     from pyspark.sql import DataFrame

     from gentropy.common.session import Session

@@ -55,13 +55,15 @@
     @classmethod
     def from_features_list(
         cls: Type[L2GFeatureMatrix],
         session: Session,
-        features_list: list[dict[str, str]],
+        features_list: list[str],
+        features_input_loader: L2GFeatureInputLoader,
     ) -> L2GFeatureMatrix:
         """Generate features from the gentropy datasets by calling the feature factory that will instantiate the corresponding features.

         Args:
             session (Session): Session object
-            features_list (list[dict[str, str]]): List of objects with 2 keys corresponding to the features to generate: 'name' and 'path'.
+            features_list (list[str]): List of names of the features to generate.
+            features_input_loader (L2GFeatureInputLoader): Object that contains features input.

         Returns:
             L2GFeatureMatrix: L2G feature matrix dataset
         """
         features_long_df = reduce(
             lambda x, y: x.unionByName(y, allowMissingColumns=True),
             [
                 # Compute all features and merge them into a single dataframe
                 feature.df
-                for feature in FeatureFactory.generate_features(session, features_list)
+                for feature in FeatureFactory.generate_features(
+                    session, features_list, features_input_loader
+                )
             ],
         )
         return cls(
diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py
index 5d5337db3..84c738a7b 100644
--- a/src/gentropy/l2g.py
+++ b/src/gentropy/l2g.py
@@ -12,7 +12,7 @@

 from gentropy.common.utils import access_gcp_secret
 from gentropy.config import LocusToGeneConfig
 from gentropy.dataset.colocalisation import Colocalisation
-from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
+from gentropy.dataset.l2g_feature_matrix import L2GFeatureInputLoader, L2GFeatureMatrix
 from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
 from gentropy.dataset.l2g_prediction import L2GPrediction
 from gentropy.dataset.study_index import StudyIndex
@@ -103,6 +103,11 @@
             if colocalisation_path
             else None
         )
+        self.features_input_loader = L2GFeatureInputLoader(
+            v2g=self.v2g,
+            coloc=self.coloc,
+            studies=self.studies,
+        )

         if run_mode == "predict":
             if not (self.studies and self.v2g and self.coloc):
                 raise ValueError("Dependencies for predict mode not set.")
             self.run_predict()
@@ -125,6 +130,7 @@
             raise ValueError("predictions_path must be set for predict mode.")
         # TODO: IMPROVE - it is not correct that L2GPrediction outputs a feature matrix - FM should be written when training
         predictions, feature_matrix = L2GPrediction.from_credible_set(
+            # TODO: rewrite this function to use the new FM generation
             self.features_list,
             self.credible_set,
             self.studies,
@@ -222,8 +228,7 @@

         # TODO: Should StudyLocus and GoldStandard have an `annotate_w_features` method?
-        fm = L2GFeatureMatrix.from_features_list(
-            self.session,
-            self.features_list,
+        fm = L2GFeatureMatrix.from_features_list(
+            self.session, self.features_list, self.features_input_loader
         )

         return (
diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py
index ecc89a9f2..68e23876f 100644
--- a/src/gentropy/method/l2g/feature_factory.py
+++ b/src/gentropy/method/l2g/feature_factory.py
@@ -2,344 +2,45 @@

 from __future__ import annotations

-from functools import reduce
-from itertools import chain
-from typing import TYPE_CHECKING, Any, Mapping
+from typing import TYPE_CHECKING, Any, Iterator, Mapping

 import pyspark.sql.functions as f

 from gentropy.common.session import Session
-from gentropy.common.spark_helpers import (
-    convert_from_wide_to_long,
-    get_record_with_maximum_value,
-)
+from gentropy.common.spark_helpers import convert_from_wide_to_long
 from gentropy.dataset.l2g_feature import L2GFeature
-from gentropy.dataset.study_locus import CredibleInterval, StudyLocus
-from gentropy.method.colocalisation import Coloc, ECaviar
+from gentropy.dataset.study_locus import StudyLocus

 if TYPE_CHECKING:
-    from pyspark.sql import Column, DataFrame
-
-    from gentropy.dataset.colocalisation import Colocalisation
-    from gentropy.dataset.study_index import StudyIndex
     from gentropy.dataset.v2g import V2G


-class ColocalisationFactory:
-    """Feature extraction in colocalisation."""
+class L2GFeatureInputLoader:
+    """Loads all input datasets required for the L2GFeature dataset."""

-    @classmethod
-    def _add_colocalisation_metric(cls: type[ColocalisationFactory]) -> Column:
-        """Expression that adds a `colocalisationMetric` column to the colocalisation dataframe in preparation for feature extraction.
-
-        Returns:
-            Column: The expression that adds a `colocalisationMetric` column with the derived metric
-        """
-        method_metric_map = {
-            ECaviar.METHOD_NAME: ECaviar.METHOD_METRIC,
-            Coloc.METHOD_NAME: Coloc.METHOD_METRIC,
-        }
-        map_expr = f.create_map(*[f.lit(x) for x in chain(*method_metric_map.items())])
-        return map_expr[f.col("colocalisationMethod")].alias("colocalisationMetric")
+    def __init__(
+        self,
+        **kwargs: Any,
+    ) -> None:
+        """Initializes L2GFeatureInputLoader with the provided inputs, storing the non-null ones as the list of dependencies."""
+        self.input_dependencies = [v for v in kwargs.values() if v is not None]

-    @staticmethod
-    def _get_max_coloc_per_credible_set(
-        colocalisation: Colocalisation,
-        credible_set: StudyLocus,
-        studies: StudyIndex,
-    ) -> L2GFeature:
-        """Get the maximum colocalisation posterior probability for each pair of overlapping study-locus per type of colocalisation method and QTL type.
-
-        Args:
-            colocalisation (Colocalisation): Colocalisation dataset
-            credible_set (StudyLocus): Study locus dataset
-            studies (StudyIndex): Study index dataset
-
-        Returns:
-            L2GFeature: Stores the features with the max coloc probabilities for each pair of study-locus
-        """
-        colocalisation_df = colocalisation.df.select(
-            f.col("leftStudyLocusId").alias("studyLocusId"),
-            "rightStudyLocusId",
-            f.coalesce("h4", "clpp").alias("score"),
-            ColocalisationFactory._add_colocalisation_metric(),
-        )
-
-        colocalising_credible_sets = (
-            credible_set.df.select("studyLocusId", "studyId")
-            # annotate studyLoci with overlapping IDs on the left - to just keep GWAS associations
-            .join(
-                colocalisation_df,
-                on="studyLocusId",
-                how="inner",
-            )
-            # bring study metadata to just keep QTL studies on the right
-            .join(
-                credible_set.df.join(
-                    studies.df.select("studyId", "studyType", "geneId"), "studyId"
-                ).selectExpr(
-                    "studyLocusId as rightStudyLocusId",
-                    "studyType as right_studyType",
-                    "geneId",
-                ),
-                on="rightStudyLocusId",
-                how="inner",
-            )
-            .filter(f.col("right_studyType") != "gwas")
-            .select(
-                "studyLocusId",
-                "right_studyType",
-                "geneId",
-                "score",
-                "colocalisationMetric",
-            )
-        )
+    def get_dependency(self, dependency_type: Any) -> Any:
+        """Returns the dependency that matches the provided type."""
+        for dependency in self.input_dependencies:
+            if isinstance(dependency, dependency_type):
+                return dependency

-        # Max PP calculation per credible set AND type of QTL AND colocalisation method
-        local_max = (
-            get_record_with_maximum_value(
-                colocalising_credible_sets,
-                ["studyLocusId", "right_studyType", "geneId", "colocalisationMetric"],
-                "score",
-            )
-            .select(
-                "*",
-                f.col("score").alias("max_score"),
-                f.lit("Local").alias("score_type"),
-            )
-            .drop("score")
-        )
+    def __iter__(self) -> Iterator[Any]:
+        """Make the class iterable, returning an iterator over the input dependencies."""
+        return iter(self.input_dependencies)

-        neighbourhood_max = (
-            local_max.selectExpr(
-                "studyLocusId", "max_score as local_max_score", "geneId"
-            )
-            .join(
-                # Add maximum in the neighborhood
-                get_record_with_maximum_value(
-                    colocalising_credible_sets.withColumnRenamed(
-                        "score", "tmp_nbh_max_score"
-                    ),
-                    ["studyLocusId", "right_studyType", "colocalisationMetric"],
-                    "tmp_nbh_max_score",
-                ).drop("geneId"),
-                on="studyLocusId",
-            )
-            .withColumn("score_type", f.lit("Neighborhood"))
-            .withColumn(
-                "max_score",
-                f.log10(
-                    f.abs(
-                        f.col("local_max_score")
-                        - f.col("tmp_nbh_max_score")
-                        + f.lit(0.0001)  # intercept
-                    )
-                ),
-            )
-        ).drop("tmp_nbh_max_score", "local_max_score")
+    def __repr__(self) -> str:
+        """Return a string representation of the input dependencies.
+
+        Useful for understanding the loader content without having to print the object attribute.
+        """
+        return repr(self.input_dependencies)

-        return L2GFeature(
-            _df=(
-                # Combine local and neighborhood metrics
-                local_max.unionByName(
-                    neighbourhood_max, allowMissingColumns=True
-                ).select(
-                    "studyLocusId",
-                    "geneId",
-                    # Feature name is a concatenation of the QTL type, colocalisation metric and if it's local or in the vicinity
-                    f.concat_ws(
-                        "",
-                        f.col("right_studyType"),
-                        f.lit("Coloc"),
-                        f.initcap(f.col("colocalisationMetric")),
-                        f.lit("Maximum"),
-                        f.regexp_replace(f.col("score_type"), "Local", ""),
-                    ).alias("featureName"),
-                    f.col("max_score").cast("float").alias("featureValue"),
-                )
-            ),
-            _schema=L2GFeature.get_schema(),
-        )
-
-
-class StudyLocusFactory(StudyLocus):
-    """Feature extraction in study locus."""
-
-    @staticmethod
-    def _get_tss_distance_features(credible_set: StudyLocus, v2g: V2G) -> L2GFeature:
-        """Joins StudyLocus with the V2G to extract a score that is based on the distance to a gene TSS of any variant weighted by its posterior probability in a credible set.
+ def get_dependency(self, dependency_type: Any) -> Any: + """Returns the dependency that matches the provided type.""" + for dependency in self.input_dependencies: + if isinstance(dependency, dependency_type): + return dependency - Args: - credible_set (StudyLocus): Credible set dataset - v2g (V2G): Dataframe containing the distances of all variants to all genes TSS within a region + def __iter__(self) -> list[Any]: + """Make the class iterable, returning the input dependencies list.""" + return iter(self.input_dependencies) - Returns: - L2GFeature: Stores the features with the score of weighting the distance to the TSS by the posterior probability of the variant + def __repr__(self) -> str: + """Return a string representation of the input dependencies. + Useful for understanding the loader content without having to print the object attribute. """ - wide_df = ( - credible_set.filter_credible_set(CredibleInterval.IS95) - .df.withColumn("variantInLocus", f.explode_outer("locus")) - .select( - "studyLocusId", - "variantId", - f.col("variantInLocus.variantId").alias("variantInLocusId"), - f.col("variantInLocus.posteriorProbability").alias( - "variantInLocusPosteriorProbability" - ), - ) - .join( - v2g.df.filter(f.col("datasourceId") == "canonical_tss").selectExpr( - "variantId as variantInLocusId", "geneId", "score" - ), - on="variantInLocusId", - how="inner", - ) - .withColumn( - "weightedScore", - f.col("score") * f.col("variantInLocusPosteriorProbability"), - ) - .groupBy("studyLocusId", "geneId") - .agg( - f.min("weightedScore").alias("distanceTssMinimum"), - f.mean("weightedScore").alias("distanceTssMean"), - ) - ) - - return L2GFeature( - _df=convert_from_wide_to_long( - wide_df, - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ), - _schema=L2GFeature.get_schema(), - ) - - @staticmethod - def _get_vep_features( - credible_set: StudyLocus, - v2g: V2G, - ) -> L2GFeature: - """Get the maximum VEP score for all variants in a locus's 95% credible set. - - This informs about functional impact of the variants in the locus. For more information on variant consequences, see: https://www.ensembl.org/info/genome/variation/prediction/predicted_data.html - Two metrics: max VEP score per study locus and gene, and max VEP score per study locus. - - - Args: - credible_set (StudyLocus): Study locus dataset with the associations to be annotated - v2g (V2G): V2G dataset with the variant/gene relationships and their consequences - - Returns: - L2GFeature: Stores the features with the max VEP score. - """ - - def _aggregate_vep_feature( - df: DataFrame, - aggregation_expr: Column, - aggregation_cols: list[str], - feature_name: str, - ) -> DataFrame: - """Extracts the maximum or average VEP score after grouping by the given columns. Different aggregations return different predictive annotations. - - If the group_cols include "geneId", the maximum/mean VEP score per gene is returned. - Otherwise, the maximum/mean VEP score for all genes in the neighborhood of the locus is returned. 
- - Args: - df (DataFrame): DataFrame with the VEP scores for each variant in a studyLocus - aggregation_expr (Column): Aggregation expression to apply - aggregation_cols (list[str]): Columns to group by - feature_name (str): Name of the feature to be returned - - Returns: - DataFrame: DataFrame with the maximum VEP score per locus or per locus/gene - """ - if "geneId" in aggregation_cols: - return df.groupBy(aggregation_cols).agg( - aggregation_expr.alias(feature_name) - ) - return ( - df.groupBy(aggregation_cols) - .agg( - aggregation_expr.alias(feature_name), - f.collect_set("geneId").alias("geneId"), - ) - .withColumn("geneId", f.explode("geneId")) - ) - - credible_set_w_variant_consequences = ( - credible_set.filter_credible_set(CredibleInterval.IS95) - .df.withColumn("variantInLocus", f.explode_outer("locus")) - .select( - f.col("studyLocusId"), - f.col("variantId"), - f.col("studyId"), - f.col("variantInLocus.variantId").alias("variantInLocusId"), - f.col("variantInLocus.posteriorProbability").alias( - "variantInLocusPosteriorProbability" - ), - ) - .join( - # Join with V2G to get variant consequences - v2g.df.filter(f.col("datasourceId") == "variantConsequence").selectExpr( - "variantId as variantInLocusId", "geneId", "score" - ), - on="variantInLocusId", - ) - .select( - "studyLocusId", - "variantId", - "studyId", - "geneId", - (f.col("score") * f.col("variantInLocusPosteriorProbability")).alias( - "weightedScore" - ), - ) - .distinct() - ) - - return L2GFeature( - _df=convert_from_wide_to_long( - reduce( - lambda x, y: x.unionByName(y, allowMissingColumns=True), - [ - # Calculate overall max VEP score for all genes in the vicinity - credible_set_w_variant_consequences.transform( - _aggregate_vep_feature, - f.max("weightedScore"), - ["studyLocusId"], - "vepMaximumNeighborhood", - ), - # Calculate overall max VEP score per gene - credible_set_w_variant_consequences.transform( - _aggregate_vep_feature, - f.max("weightedScore"), - ["studyLocusId", "geneId"], - "vepMaximum", - ), - # Calculate mean VEP score for all genes in the vicinity - credible_set_w_variant_consequences.transform( - _aggregate_vep_feature, - f.mean("weightedScore"), - ["studyLocusId"], - "vepMeanNeighborhood", - ), - # Calculate mean VEP score per gene - credible_set_w_variant_consequences.transform( - _aggregate_vep_feature, - f.mean("weightedScore"), - ["studyLocusId", "geneId"], - "vepMean", - ), - ], - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ).filter(f.col("featureValue").isNotNull()), - _schema=L2GFeature.get_schema(), - ) + return repr(self.input_dependencies) class DistanceTssMinimumFeature(L2GFeature): @@ -367,31 +68,26 @@ class DistanceTssMeanFeature(L2GFeature): """Average distance of all tagging variants to gene TSS.""" fill_na_value = 500_000 - - @classmethod - def dummy( - cls: type[DistanceTssMeanFeature], - input_dependency: Any, - credible_set: StudyLocus, - ): - cls.input_dependency = input_dependency - cls.credible_set = credible_set - return cls + feature_dependency = V2G @classmethod def compute( cls: type[DistanceTssMeanFeature], - input_dependency: V2G, credible_set: StudyLocus, + feature_dependency: V2G, ) -> Any: """Computes the feature. 
+        Args:
+            credible_set (StudyLocus): Credible set dependency
+            feature_dependency (V2G): Dataset that contains the distance information
+
         Returns:
             L2GFeature: Feature dataset
         """
         agg_expr = f.mean("weightedScore").alias("distanceTssMean")
         # Everything but expression is common logic
-        v2g = input_dependency.df.filter(f.col("datasourceId") == "canonical_tss")
+        v2g = feature_dependency.df.filter(f.col("datasourceId") == "canonical_tss")
         wide_df = (
             credible_set.df.withColumn("variantInLocus", f.explode_outer("locus"))
             .select(
@@ -427,13 +123,13 @@ def compute(
 class FeatureFactory:
     """Factory class for creating features."""
 
-    # TODO: should this live in the `features_list`?
     feature_mapper: Mapping[str, type[L2GFeature]] = {
         # "distanceTssMinimum": DistanceTssMinimumFeature,
         "distanceTssMean": DistanceTssMeanFeature,
     }
+    features_input_loader: L2GFeatureInputLoader
 
-    def __init__(self: type[FeatureFactory], credible_set: StudyLocus) -> None:
+    def __init__(self: FeatureFactory, credible_set: StudyLocus) -> None:
         """Initializes the factory.
 
         Args:
@@ -446,53 +142,40 @@ def generate_features(
         cls: type[FeatureFactory],
         session: Session,
         features_list: list[dict[str, str]],
+        credible_set_path: str,
+        features_input_loader: L2GFeatureInputLoader,
     ) -> list[L2GFeature]:
         """Generates a feature matrix by reading an object with instructions on how to create the features.
 
         Args:
             session (Session): session object
             features_list (list[dict[str, str]]): list of objects with 2 keys: 'name' and 'path'.
+            credible_set_path (str): path to credible set parquet file.
+            features_input_loader (L2GFeatureInputLoader): object with required feature dependencies.
 
         Returns:
             list[L2GFeature]: list of computed features.
         """
+        cls.features_input_loader = features_input_loader
         computed_features = []
         for feature in features_list:
             if feature["name"] in cls.feature_mapper:
-                input_dependency = cls.inject_dependency(session, feature["path"])
-                computed_features.append(
-                    cls.compute_feature(feature["name"], input_dependency)
-                )
+                computed_features.append(cls.compute_feature(feature["name"]))
             else:
                 raise ValueError(f"Feature {feature['name']} not found.")
         return computed_features
 
-    @classmethod
-    def compute_feature(
-        cls: type[FeatureFactory], feature_name: str, input_dependency: Any
-    ) -> L2GFeature:
+    def compute_feature(self: FeatureFactory, feature_name: str) -> L2GFeature:
         """Instantiates feature class.
 
         Args:
             feature_name (str): name of the feature
-            input_dependency (Any): dependency object
-
-        Returns:
-            L2GFeature: instantiated feature object
-        """
-        feature_cls = cls.feature_mapper[feature_name]
-        return feature_cls.compute(input_dependency)
-
-    @classmethod
-    def inject_dependency(
-        cls: type[FeatureFactory], session: Session, feature_dependency_path: str
-    ) -> Any:
-        """Injects a dependency into the feature factory.
-        Args:
-            session (Session): session object
-            feature_dependency_path (str): path to the dependency of the feature
 
         Returns:
-            Any: dependency object
+            L2GFeature: instantiated feature object
         """
-        # TODO: Dependency injection feature responsability?
-        return V2G.from_parquet(session, feature_dependency_path)
+        feature_cls = self.feature_mapper[feature_name]
+        # Filter features_input_loader to pass only the dependency that the feature needs
+        feature_input_type = feature_cls.feature_dependency
+        feature_input = cls.features_input_loader.get_dependency(feature_input_type)
+        return feature_cls.compute(feature_input)

From 09d5291008a598153409d852eaf69a7c75914127 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Mon, 9 Sep 2024 16:27:33 +0200
Subject: [PATCH 10/48] chore: fix mypy issues

---
 config/step/ot_locus_to_gene_train.yaml    |   6 +-
 src/gentropy/dataset/l2g_feature_matrix.py |  11 +-
 src/gentropy/dataset/l2g_prediction.py     |  34 ++--
 src/gentropy/l2g.py                        | 226 +++++++++++----------
 src/gentropy/method/l2g/feature_factory.py |  74 ++++---
 5 files changed, 191 insertions(+), 160 deletions(-)

diff --git a/config/step/ot_locus_to_gene_train.yaml b/config/step/ot_locus_to_gene_train.yaml
index c6d1d7ffd..cf6d0a574 100644
--- a/config/step/ot_locus_to_gene_train.yaml
+++ b/config/step/ot_locus_to_gene_train.yaml
@@ -18,9 +18,7 @@ hyperparameters:
   loss: log_loss
 download_from_hub: true
 features_list:
-- name: distanceTssMean
   # average distance of all tagging variants to gene TSS
-  path: ${datasets.variant_to_gene}
-- name: distanceTssMinimum
+  - distanceTssMean
   # minimum distance of all tagging variants to gene TSS
-  path: ${datasets.variant_to_gene}
+  - distanceTssMinimum
diff --git a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py
index 5640f513c..3722ca15b 100644
--- a/src/gentropy/dataset/l2g_feature_matrix.py
+++ b/src/gentropy/dataset/l2g_feature_matrix.py
@@ -12,6 +12,7 @@
     from pyspark.sql import DataFrame
 
     from gentropy.common.session import Session
+    from gentropy.dataset.study_locus import StudyLocus
 
 
 class L2GFeatureMatrix:
@@ -55,6 +56,7 @@ def __init__(
     def from_features_list(
         cls: Type[L2GFeatureMatrix],
         session: Session,
+        credible_set: StudyLocus,
         features_list: list[str],
         features_input_loader: L2GFeatureInputLoader,
     ) -> L2GFeatureMatrix:
         """Generate features from the gentropy datasets by calling the feature factory that will instantiate the corresponding features.
 
         Args:
             session (Session): Session object
-            features_list (list[str]): List of objects with 2 keys corresponding to the features to generate: 'name' and 'path'.
+            credible_set (StudyLocus): Credible set of study locus pairs to annotate
+            features_list (list[str]): List of feature names to be computed.
             features_input_loader (L2GFeatureInputLoader): Object that contains features input.
Returns: @@ -73,9 +76,9 @@ def from_features_list( [ # Compute all features and merge them into a single dataframe feature.df - for feature in FeatureFactory.generate_features( - session, features_list, features_input_loader - ) + for feature in FeatureFactory( + credible_set, features_list + ).generate_features(session, features_input_loader) ], ) return cls( diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 724ada584..3a88ba1c1 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -7,12 +7,11 @@ from gentropy.common.schemas import parse_spark_schema from gentropy.common.session import Session -from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.dataset import Dataset from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus -from gentropy.dataset.v2g import V2G +from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader from gentropy.method.l2g.model import LocusToGeneModel if TYPE_CHECKING: @@ -40,12 +39,10 @@ def get_schema(cls: type[L2GPrediction]) -> StructType: @classmethod def from_credible_set( cls: Type[L2GPrediction], - features_list: list[str], - credible_set: StudyLocus, - study_index: StudyIndex, - v2g: V2G, - coloc: Colocalisation, session: Session, + credible_set: StudyLocus, + features_list: list[str], + features_input_loader: L2GFeatureInputLoader, model_path: str | None, hf_token: str | None = None, download_from_hub: bool = True, @@ -53,12 +50,10 @@ def from_credible_set( """Extract L2G predictions for a set of credible sets derived from GWAS. Args: - features_list (list[str]): List of features to use for the model - credible_set (StudyLocus): Credible set dataset - study_index (StudyIndex): Study index dataset - v2g (V2G): Variant to gene dataset - coloc (Colocalisation): Colocalisation dataset session (Session): Session object that contains the Spark session + credible_set (StudyLocus): Credible set dataset + features_list (list[str]): List of features to use for the model + features_input_loader (L2GFeatureInputLoader): Loader with all feature dependencies model_path (str | None): Path to the model file. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name). hf_token (str | None): Hugging Face token to download the model from the Hub. Only required if the model is private. download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True. 
@@ -75,20 +70,19 @@ def from_credible_set( l2g_model = LocusToGeneModel.load_from_disk(model_path) # Prepare data - fm = L2GFeatureMatrix.generate_features( - features_list=features_list, + fm = L2GFeatureMatrix.from_features_list( + session, credible_set=credible_set, - study_index=study_index, - variant_gene=v2g, - colocalisation=coloc, + features_list=features_list, + features_input_loader=features_input_loader, ).fill_na() gwas_fm = L2GFeatureMatrix( _df=( fm._df.join( - credible_set.filter_by_study_type("gwas", study_index).df.select( - "studyLocusId" - ), + credible_set.filter_by_study_type( + "gwas", features_input_loader.get_dependency(StudyIndex) + ).df.select("studyLocusId"), on="studyLocusId", ) ), diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 84c738a7b..4785a1125 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -12,12 +12,13 @@ from gentropy.common.utils import access_gcp_secret from gentropy.config import LocusToGeneConfig from gentropy.dataset.colocalisation import Colocalisation -from gentropy.dataset.l2g_feature_matrix import L2GFeatureInputLoader, L2GFeatureMatrix +from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.l2g_prediction import L2GPrediction from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus from gentropy.dataset.v2g import V2G +from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader from gentropy.method.l2g.model import LocusToGeneModel from gentropy.method.l2g.trainer import LocusToGeneTrainer @@ -110,137 +111,146 @@ def __init__( ) if run_mode == "predict": - if not self.studies and self.v2g and self.coloc: - raise ValueError("Dependencies for predict mode not set.") self.run_predict() elif run_mode == "train": - if not gold_standard_curation_path and gene_interactions_path: - raise ValueError("Dependencies for train mode not set.") - self.gs_curation = self.session.spark.read.json(gold_standard_curation_path) - self.interactions = self.session.spark.read.parquet(gene_interactions_path) + self.gs_curation = ( + self.session.spark.read.json(gold_standard_curation_path) + if gold_standard_curation_path + else None + ) + self.interactions = ( + self.session.spark.read.parquet(gene_interactions_path) + if gene_interactions_path + else None + ) self.run_train() def run_predict(self) -> None: """Run the prediction step. Raises: - ValueError: If predictions_path is not set. 
+ ValueError: If not all dependencies in prediction mode are set """ - if not self.predictions_path: - raise ValueError("predictions_path must be set for predict mode.") # TODO: IMPROVE - it is not correct that L2GPrediction outputs a feature matrix - FM should be written when training - predictions, feature_matrix = L2GPrediction.from_credible_set( - # TODO: rewrite this function to use the new FM generation - self.features_list, - self.credible_set, - self.studies, - self.v2g, - self.coloc, - self.session, - model_path=self.model_path, - hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"), - download_from_hub=self.download_from_hub, - ) - if self.feature_matrix_path: - feature_matrix._df.write.mode(self.session.write_mode).parquet( - self.feature_matrix_path + if self.studies and self.v2g and self.coloc: + predictions, feature_matrix = L2GPrediction.from_credible_set( + self.session, + self.credible_set, + self.features_list, + self.features_input_loader, + model_path=self.model_path, + hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"), + download_from_hub=self.download_from_hub, ) - predictions.df.write.mode(self.session.write_mode).parquet( - self.predictions_path - ) - self.session.logger.info(self.predictions_path) + if self.feature_matrix_path: + feature_matrix._df.write.mode(self.session.write_mode).parquet( + self.feature_matrix_path + ) + if self.predictions_path: + predictions.df.write.mode(self.session.write_mode).parquet( + self.predictions_path + ) + self.session.logger.info(self.predictions_path) + else: + raise ValueError("Dependencies for predict mode not set.") def run_train(self) -> None: - """Run the training step. + """Run the training step.""" + if ( + self.gs_curation + and self.interactions + and self.v2g + and self.wandb_run_name + and self.model_path + ): + wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev") + # Process gold standard and L2G features + data = self._generate_feature_matrix() - Raises: - ValueError: If gold_standard_curation_path, gene_interactions_path, or wandb_run_name are not set. - """ - if not (self.wandb_run_name and self.model_path): - raise ValueError( - "wandb_run_name, and a path to save the model must be set for train mode." 
+ # Instantiate classifier and train model + l2g_model = LocusToGeneModel( + model=GradientBoostingClassifier(random_state=42), + hyperparameters=self.hyperparameters, ) - - wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev") - # Process gold standard and L2G features - data = self._generate_feature_matrix() - - # Instantiate classifier and train model - l2g_model = LocusToGeneModel( - model=GradientBoostingClassifier(random_state=42), - hyperparameters=self.hyperparameters, - ) - wandb_login(key=wandb_key) - trained_model = LocusToGeneTrainer(model=l2g_model, feature_matrix=data).train( - self.wandb_run_name - ) - if trained_model.training_data and trained_model.model: - trained_model.save(self.model_path) - if self.hf_hub_repo_id: - hf_hub_token = access_gcp_secret( - "hfhub-key", "open-targets-genetics-dev" - ) - trained_model.export_to_hugging_face_hub( - # we upload the model in the filesystem - self.model_path.split("/")[-1], - hf_hub_token, - data=trained_model.training_data._df.drop( - "goldStandardSet", "geneId" - ).toPandas(), - repo_id=self.hf_hub_repo_id, - commit_message="chore: update model", - ) + wandb_login(key=wandb_key) + trained_model = LocusToGeneTrainer( + model=l2g_model, feature_matrix=data + ).train(self.wandb_run_name) + if trained_model.training_data and trained_model.model and self.model_path: + trained_model.save(self.model_path) + if self.hf_hub_repo_id: + hf_hub_token = access_gcp_secret( + "hfhub-key", "open-targets-genetics-dev" + ) + trained_model.export_to_hugging_face_hub( + # we upload the model in the filesystem + self.model_path.split("/")[-1], + hf_hub_token, + data=trained_model.training_data._df.drop( + "goldStandardSet", "geneId" + ).toPandas(), + repo_id=self.hf_hub_repo_id, + commit_message="chore: update model", + ) def _generate_feature_matrix(self) -> L2GFeatureMatrix: """Generate the feature matrix for training. Returns: L2GFeatureMatrix: Feature matrix with gold standards annotated with features. + + Raises: + ValueError: If dependencies to build features are not set. 
""" - study_locus_overlap = StudyLocus( - _df=self.credible_set.df.join( - f.broadcast( - self.gs_curation.select( - StudyLocus.assign_study_locus_id( - f.col("association_info.otg_id"), # studyId - f.concat_ws( # variantId - "_", - f.col("sentinel_variant.locus_GRCh38.chromosome"), - f.col("sentinel_variant.locus_GRCh38.position"), - f.col("sentinel_variant.alleles.reference"), - f.col("sentinel_variant.alleles.alternative"), - ), - ).alias("studyLocusId"), - ) + if self.gs_curation and self.interactions and self.v2g and self.studies: + study_locus_overlap = StudyLocus( + _df=self.credible_set.df.join( + f.broadcast( + self.gs_curation.select( + StudyLocus.assign_study_locus_id( + f.col("association_info.otg_id"), # studyId + f.concat_ws( # variantId + "_", + f.col("sentinel_variant.locus_GRCh38.chromosome"), + f.col("sentinel_variant.locus_GRCh38.position"), + f.col("sentinel_variant.alleles.reference"), + f.col("sentinel_variant.alleles.alternative"), + ), + ).alias("studyLocusId"), + ) + ), + "studyLocusId", + "inner", ), - "studyLocusId", - "inner", - ), - _schema=StudyLocus.get_schema(), - ).find_overlaps(self.studies) - - gold_standards = L2GGoldStandard.from_otg_curation( - gold_standard_curation=self.gs_curation, - v2g=self.v2g, - study_locus_overlap=study_locus_overlap, - interactions=self.interactions, - ) + _schema=StudyLocus.get_schema(), + ).find_overlaps(self.studies) - # TODO: Should StudyLocus and GoldStandard have an `annotate_w_features` method? - fm = L2GFeatureMatrix.from_features_list( - self.session, self.features_list, self.features_input_loader - ) + gold_standards = L2GGoldStandard.from_otg_curation( + gold_standard_curation=self.gs_curation, + v2g=self.v2g, + study_locus_overlap=study_locus_overlap, + interactions=self.interactions, + ) - return ( - L2GFeatureMatrix( - _df=fm._df.join( - f.broadcast( - gold_standards.df.drop("variantId", "studyId", "sources") + # TODO: Should StudyLocus and GoldStandard have an `annotate_w_features` method? + fm = L2GFeatureMatrix.from_features_list( + self.session, + self.credible_set, + self.features_list, + self.features_input_loader, + ) + + return ( + L2GFeatureMatrix( + _df=fm._df.join( + f.broadcast( + gold_standards.df.drop("variantId", "studyId", "sources") + ), + on=["studyLocusId", "geneId"], + how="inner", ), - on=["studyLocusId", "geneId"], - how="inner", - ), + ) + .fill_na() + .select_features(self.features_list) ) - .fill_na() - .select_features(self.features_list) - ) + raise ValueError("Dependencies for train mode not set.") diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 68e23876f..e58e5b4ae 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Mapping +from typing import TYPE_CHECKING, Any, Iterator, Mapping import pyspark.sql.functions as f @@ -20,25 +20,43 @@ class L2GFeatureInputLoader: def __init__( self, - **kwargs: dict[str, Any], + **kwargs: Any, ) -> None: - """Initializes L2GFeatureInputLoader with the provided inputs and returns loaded dependencies as a list.""" + """Initializes L2GFeatureInputLoader with the provided inputs and returns loaded dependencies as a list. + + Args: + **kwargs (Any): keyword arguments with the name of the dependency and the dependency itself. 
+        """
         self.input_dependencies = [v for v in kwargs.values() if v is not None]
 
     def get_dependency(self, dependency_type: Any) -> Any:
-        """Returns the dependency that matches the provided type."""
+        """Returns the dependency that matches the provided type.
+
+        Args:
+            dependency_type (Any): type of the dependency to return.
+
+        Returns:
+            Any: dependency that matches the provided type.
+        """
         for dependency in self.input_dependencies:
             if isinstance(dependency, dependency_type):
                 return dependency
 
-    def __iter__(self) -> list[Any]:
-        """Make the class iterable, returning the input dependencies list."""
+    def __iter__(self) -> Iterator[dict[str, Any]]:
+        """Make the class iterable, returning the input dependencies list.
+
+        Returns:
+            Iterator[dict[str, Any]]: list of input dependencies.
+        """
         return iter(self.input_dependencies)
 
     def __repr__(self) -> str:
         """Return a string representation of the input dependencies.
 
         Useful for understanding the loader content without having to print the object attribute.
+
+        Returns:
+            str: string representation of the input dependencies.
         """
         return repr(self.input_dependencies)
 
@@ -48,11 +66,14 @@ class DistanceTssMinimumFeature(L2GFeature):
 
     @classmethod
     def compute(
-        cls: type[DistanceTssMinimumFeature], input_dependency: V2G
+        cls: type[DistanceTssMinimumFeature],
+        credible_set: StudyLocus,
+        input_dependency: V2G,
     ) -> L2GFeature:
         """Computes the feature.
 
         Args:
+            credible_set (StudyLocus): Credible set dependency
             input_dependency (V2G): V2G dependency
 
         Returns:
@@ -75,7 +96,7 @@ def compute(
         cls: type[DistanceTssMeanFeature],
         credible_set: StudyLocus,
         feature_dependency: V2G,
-    ) -> Any:
+    ) -> DistanceTssMeanFeature:
         """Computes the feature.
 
         Args:
@@ -83,7 +104,7 @@ def compute(
             feature_dependency (V2G): Dataset that contains the distance information
 
         Returns:
-            L2GFeature: Feature dataset
+            DistanceTssMeanFeature: Feature dataset
         """
         agg_expr = f.mean("weightedScore").alias("distanceTssMean")
         # Everything but expression is common logic
@@ -129,40 +150,41 @@ class FeatureFactory:
     }
     features_input_loader: L2GFeatureInputLoader
 
-    def __init__(self: FeatureFactory, credible_set: StudyLocus) -> None:
+    def __init__(
+        self: FeatureFactory, credible_set: StudyLocus, features_list: list[str]
+    ) -> None:
         """Initializes the factory.
 
         Args:
             credible_set (StudyLocus): credible sets to annotate
+            features_list (list[str]): list of features to compute.
         """
         self.credible_set = credible_set
+        self.features_list = features_list
 
-    @classmethod
     def generate_features(
-        cls: type[FeatureFactory],
+        self: FeatureFactory,
         session: Session,
-        features_list: list[dict[str, str]],
-        credible_set_path: str,
         features_input_loader: L2GFeatureInputLoader,
     ) -> list[L2GFeature]:
         """Generates a feature matrix by reading an object with instructions on how to create the features.
 
         Args:
             session (Session): session object
-            features_list (list[dict[str, str]]): list of objects with 2 keys: 'name' and 'path'.
-            credible_set_path (str): path to credible set parquet file.
             features_input_loader (L2GFeatureInputLoader): object with required feature dependencies.
 
         Returns:
             list[L2GFeature]: list of computed features.
+
+        Raises:
+            ValueError: If feature not found.
""" - cls.features_input_loader = features_input_loader computed_features = [] - for feature in features_list: - if feature["name"] in cls.feature_mapper: - computed_features.append(cls.compute_feature(feature["name"])) + for feature in self.features_list: + if feature in self.feature_mapper: + computed_features.append(self.compute_feature(feature)) else: - raise ValueError(f"Feature {feature['name']} not found.") + raise ValueError(f"Feature {feature} not found.") return computed_features def compute_feature(self: FeatureFactory, feature_name: str) -> L2GFeature: @@ -174,8 +196,12 @@ def compute_feature(self: FeatureFactory, feature_name: str) -> L2GFeature: Returns: L2GFeature: instantiated feature object """ + # Extract feature class and dependency type feature_cls = self.feature_mapper[feature_name] - # Filter features_input_loader to pass only the dependency that the feature needs feature_input_type = feature_cls.feature_dependency - feature_input = cls.features_input_loader.get_dependency(feature_input_type) - return feature_cls.compute(feature_input) + return feature_cls.compute( + credible_set=self.credible_set, + feature_dependency=self.features_input_loader.get_dependency( + feature_input_type + ), + ) From b1f607b4573d00125ae9a22a1f7660cc45e8e152 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Mon, 9 Sep 2024 17:41:19 +0200 Subject: [PATCH 11/48] feat: l2gfeaturematrix.from_features_list working --- src/gentropy/dataset/l2g_feature.py | 2 +- src/gentropy/dataset/l2g_feature_matrix.py | 6 +++-- src/gentropy/dataset/l2g_prediction.py | 1 + src/gentropy/l2g.py | 4 +++- src/gentropy/method/l2g/feature_factory.py | 26 ++++++++++++---------- 5 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py index 2bbebbce7..c46d2cb0f 100644 --- a/src/gentropy/dataset/l2g_feature.py +++ b/src/gentropy/dataset/l2g_feature.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from gentropy.common.schemas import parse_spark_schema from gentropy.dataset.dataset import Dataset diff --git a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py index 3722ca15b..375be5727 100644 --- a/src/gentropy/dataset/l2g_feature_matrix.py +++ b/src/gentropy/dataset/l2g_feature_matrix.py @@ -22,7 +22,7 @@ def __init__( self, _df: DataFrame, features_list: list[str] | None = None, - mode: str = "train", + mode: str = "predict", ) -> None: """Post-initialisation to set the features list. If not provided, all columns except the fixed ones are used. @@ -59,6 +59,7 @@ def from_features_list( credible_set: StudyLocus, features_list: list[str], features_input_loader: L2GFeatureInputLoader, + mode: str, ) -> L2GFeatureMatrix: """Generate features from the gentropy datasets by calling the feature factory that will instantiate the corresponding features. @@ -67,6 +68,7 @@ def from_features_list( credible_set (StudyLocus): Credible set of study locus pairs to annotate features_list (list[str]): List of feature names to be computed. features_input_loader (L2GFeatureInputLoader): Object that contais features input. + mode (str): Mode of the feature matrix. Can be either "train" or "predict". If "train", the column with the gold standard set will be added to the feature matrix. 
Returns: L2GFeatureMatrix: L2G feature matrix dataset @@ -88,6 +90,7 @@ def from_features_list( "featureName", "featureValue", ), + mode=mode, ) def calculate_feature_missingness_rate( @@ -156,4 +159,3 @@ def select_features( ) return self raise ValueError("features_list cannot be None") - raise ValueError("features_list cannot be None") diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 3a88ba1c1..a505b8a04 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -75,6 +75,7 @@ def from_credible_set( credible_set=credible_set, features_list=features_list, features_input_loader=features_input_loader, + mode="predict", ).fill_na() gwas_fm = L2GFeatureMatrix( diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 755af9fea..edc5da149 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -233,12 +233,13 @@ def _generate_feature_matrix(self) -> L2GFeatureMatrix: interactions=self.interactions, ) - # TODO: Should StudyLocus and GoldStandard have an `annotate_w_features` method? + # TODO: Should StudyLocus and GoldStandard have an `annotate_w_features` method? Yes fm = L2GFeatureMatrix.from_features_list( self.session, self.credible_set, self.features_list, self.features_input_loader, + mode="predict", # here we don't have the goldStandardSet col ) return ( @@ -250,6 +251,7 @@ def _generate_feature_matrix(self) -> L2GFeatureMatrix: on=["studyLocusId", "geneId"], how="inner", ), + mode="train", # goldStandardSet col is there after joining with the GS ) .fill_na() .select_features(self.features_list) diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index e58e5b4ae..11df4ed2d 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Iterator, Mapping +from typing import Any, Iterator, Mapping import pyspark.sql.functions as f @@ -10,9 +10,7 @@ from gentropy.common.spark_helpers import convert_from_wide_to_long from gentropy.dataset.l2g_feature import L2GFeature from gentropy.dataset.study_locus import StudyLocus - -if TYPE_CHECKING: - from gentropy.dataset.v2g import V2G +from gentropy.dataset.v2g import V2G class L2GFeatureInputLoader: @@ -68,13 +66,13 @@ class DistanceTssMinimumFeature(L2GFeature): def compute( cls: type[DistanceTssMinimumFeature], credible_set: StudyLocus, - input_dependency: V2G, + feature_dependency: V2G, ) -> L2GFeature: """Computes the feature. 
         Args:
             credible_set (StudyLocus): Credible set dependency
-            input_dependency (V2G): V2G dependency
+            feature_dependency (V2G): V2G dependency
 
         Returns:
             L2GFeature: Feature dataset
@@ -148,7 +146,6 @@ class FeatureFactory:
         # "distanceTssMinimum": DistanceTssMinimumFeature,
         "distanceTssMean": DistanceTssMeanFeature,
     }
-    features_input_loader: L2GFeatureInputLoader
 
     def __init__(
         self: FeatureFactory, credible_set: StudyLocus, features_list: list[str]
     ) -> None:
@@ -182,16 +179,23 @@ def generate_features(
         computed_features = []
         for feature in self.features_list:
             if feature in self.feature_mapper:
-                computed_features.append(self.compute_feature(feature))
+                computed_features.append(
+                    self.compute_feature(feature, features_input_loader)
+                )
             else:
                 raise ValueError(f"Feature {feature} not found.")
         return computed_features
 
-    def compute_feature(self: FeatureFactory, feature_name: str) -> L2GFeature:
+    def compute_feature(
+        self: FeatureFactory,
+        feature_name: str,
+        features_input_loader: L2GFeatureInputLoader,
+    ) -> L2GFeature:
         """Instantiates feature class.
 
         Args:
             feature_name (str): name of the feature
+            features_input_loader (L2GFeatureInputLoader): Object that contains features input.
 
         Returns:
             L2GFeature: instantiated feature object
         """
@@ -201,7 +205,5 @@ def compute_feature(
         feature_cls = self.feature_mapper[feature_name]
         feature_input_type = feature_cls.feature_dependency
         return feature_cls.compute(
             credible_set=self.credible_set,
-            feature_dependency=self.features_input_loader.get_dependency(
-                feature_input_type
-            ),
+            feature_dependency=features_input_loader.get_dependency(feature_input_type),
         )

From da200735f991141e6d5cedbc3d3cae17212f2d78 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Tue, 10 Sep 2024 08:49:47 +0200
Subject: [PATCH 12/48] chore: comment out obsolete refs

---
 .../python_api/methods/l2g/feature_factory.md |   4 -
 tests/gentropy/method/test_locus_to_gene.py   | 265 +++++++++---------
 2 files changed, 129 insertions(+), 140 deletions(-)

diff --git a/docs/python_api/methods/l2g/feature_factory.md b/docs/python_api/methods/l2g/feature_factory.md
index 35b4ed710..244796254 100644
--- a/docs/python_api/methods/l2g/feature_factory.md
+++ b/docs/python_api/methods/l2g/feature_factory.md
@@ -1,7 +1,3 @@
 ---
 title: L2G Feature Factory
 ---
-
-::: gentropy.method.l2g.feature_factory.ColocalisationFactory
-
-::: gentropy.method.l2g.feature_factory.StudyLocusFactory
diff --git a/tests/gentropy/method/test_locus_to_gene.py b/tests/gentropy/method/test_locus_to_gene.py
index 460d65062..7698e99b0 100644
--- a/tests/gentropy/method/test_locus_to_gene.py
+++ b/tests/gentropy/method/test_locus_to_gene.py
@@ -7,17 +7,10 @@
 import pytest
 from sklearn.ensemble import RandomForestClassifier
 
-from gentropy.dataset.colocalisation import Colocalisation
-from gentropy.dataset.l2g_feature import L2GFeature
-from gentropy.dataset.study_index import StudyIndex
-from gentropy.dataset.study_locus import StudyLocus
-from gentropy.method.l2g.feature_factory import ColocalisationFactory, StudyLocusFactory
 from gentropy.method.l2g.model import LocusToGeneModel
 
 if TYPE_CHECKING:
-    from pyspark.sql import SparkSession
-
-    from gentropy.dataset.v2g import V2G
+    pass
 
 
 @pytest.fixture(scope="module")
 def model() -> LocusToGeneModel:
     """Creates an instance of the LocusToGene class."""
     return LocusToGeneModel(model=RandomForestClassifier())
 
 
-class TestColocalisationFactory:
-    """Test the ColocalisationFactory methods."""
 
-    def test_get_max_coloc_per_credible_set(
-        self: TestColocalisationFactory,
-        mock_study_locus: StudyLocus,
-
mock_study_index: StudyIndex, - mock_colocalisation: Colocalisation, - ) -> None: - """Test the function that extracts the maximum log likelihood ratio for each pair of overlapping study-locus returns the right data type.""" - coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set( - mock_colocalisation, - mock_study_locus, - mock_study_index, - ) - assert isinstance( - coloc_features, L2GFeature - ), "Unexpected type returned from _get_max_coloc_per_credible_set" - - def test_get_max_coloc_per_credible_set_semantic( - self: TestColocalisationFactory, - spark: SparkSession, - ) -> None: - """Test logic of the function that extracts the maximum log likelihood ratio for each pair of overlapping study-locus.""" - # Prepare mock datasets based on 2 associations - credset = StudyLocus( - _df=spark.createDataFrame( - # 2 associations with a common variant in the locus - [ - { - "studyLocusId": 1, - "variantId": "lead1", - "studyId": "study1", # this is a GWAS - "locus": [ - {"variantId": "commonTag", "posteriorProbability": 0.9}, - ], - "chromosome": "1", - }, - { - "studyLocusId": 2, - "variantId": "lead2", - "studyId": "study2", # this is a eQTL study - "locus": [ - {"variantId": "commonTag", "posteriorProbability": 0.9}, - ], - "chromosome": "1", - }, - ], - StudyLocus.get_schema(), - ), - _schema=StudyLocus.get_schema(), - ) - - studies = StudyIndex( - _df=spark.createDataFrame( - [ - { - "studyId": "study1", - "studyType": "gwas", - "traitFromSource": "trait1", - "projectId": "project1", - }, - { - "studyId": "study2", - "studyType": "eqtl", - "geneId": "gene1", - "traitFromSource": "trait2", - "projectId": "project2", - }, - ] - ), - _schema=StudyIndex.get_schema(), - ) - coloc = Colocalisation( - _df=spark.createDataFrame( - [ - { - "leftStudyLocusId": 1, - "rightStudyLocusId": 2, - "chromosome": "1", - "colocalisationMethod": "eCAVIAR", - "numberColocalisingVariants": 1, - "clpp": 0.81, # 0.9*0.9 - "log2h4h3": None, - } - ], - schema=Colocalisation.get_schema(), - ), - _schema=Colocalisation.get_schema(), - ) - expected_coloc_features_df = spark.createDataFrame( - [ - (1, "gene1", "eqtlColocClppMaximum", 0.81), - (1, "gene1", "eqtlColocClppMaximumNeighborhood", -4.0), - ], - L2GFeature.get_schema(), - ) - # Test - coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set( - coloc, - credset, - studies, - ) - assert coloc_features.df.collect() == expected_coloc_features_df.collect() - - -class TestStudyLocusFactory: - """Test the StudyLocusFactory methods.""" - - def test_get_tss_distance_features( - self: TestStudyLocusFactory, mock_study_locus: StudyLocus, mock_v2g: V2G - ) -> None: - """Test the function that extracts the distance to the TSS.""" - tss_distance = StudyLocusFactory._get_tss_distance_features( - mock_study_locus, mock_v2g - ) - assert isinstance( - tss_distance, L2GFeature - ), "Unexpected model type returned from _get_tss_distance_features" - - def test_get_vep_features( - self: TestStudyLocusFactory, mock_study_locus: StudyLocus, mock_v2g: V2G - ) -> None: - """Test the function that extracts the VEP features.""" - vep_features = StudyLocusFactory._get_vep_features(mock_study_locus, mock_v2g) - assert isinstance( - vep_features, L2GFeature - ), "Unexpected model type returned from _get_vep_features" +# class TestColocalisationFactory: +# """Test the ColocalisationFactory methods.""" + +# def test_get_max_coloc_per_credible_set( +# self: TestColocalisationFactory, +# mock_study_locus: StudyLocus, +# mock_study_index: StudyIndex, +# 
mock_colocalisation: Colocalisation, +# ) -> None: +# """Test the function that extracts the maximum log likelihood ratio for each pair of overlapping study-locus returns the right data type.""" +# coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set( +# mock_colocalisation, +# mock_study_locus, +# mock_study_index, +# ) +# assert isinstance( +# coloc_features, L2GFeature +# ), "Unexpected type returned from _get_max_coloc_per_credible_set" + +# def test_get_max_coloc_per_credible_set_semantic( +# self: TestColocalisationFactory, +# spark: SparkSession, +# ) -> None: +# """Test logic of the function that extracts the maximum log likelihood ratio for each pair of overlapping study-locus.""" +# # Prepare mock datasets based on 2 associations +# credset = StudyLocus( +# _df=spark.createDataFrame( +# # 2 associations with a common variant in the locus +# [ +# { +# "studyLocusId": 1, +# "variantId": "lead1", +# "studyId": "study1", # this is a GWAS +# "locus": [ +# {"variantId": "commonTag", "posteriorProbability": 0.9}, +# ], +# "chromosome": "1", +# }, +# { +# "studyLocusId": 2, +# "variantId": "lead2", +# "studyId": "study2", # this is a eQTL study +# "locus": [ +# {"variantId": "commonTag", "posteriorProbability": 0.9}, +# ], +# "chromosome": "1", +# }, +# ], +# StudyLocus.get_schema(), +# ), +# _schema=StudyLocus.get_schema(), +# ) + +# studies = StudyIndex( +# _df=spark.createDataFrame( +# [ +# { +# "studyId": "study1", +# "studyType": "gwas", +# "traitFromSource": "trait1", +# "projectId": "project1", +# }, +# { +# "studyId": "study2", +# "studyType": "eqtl", +# "geneId": "gene1", +# "traitFromSource": "trait2", +# "projectId": "project2", +# }, +# ] +# ), +# _schema=StudyIndex.get_schema(), +# ) +# coloc = Colocalisation( +# _df=spark.createDataFrame( +# [ +# { +# "leftStudyLocusId": 1, +# "rightStudyLocusId": 2, +# "chromosome": "1", +# "colocalisationMethod": "eCAVIAR", +# "numberColocalisingVariants": 1, +# "clpp": 0.81, # 0.9*0.9 +# "log2h4h3": None, +# } +# ], +# schema=Colocalisation.get_schema(), +# ), +# _schema=Colocalisation.get_schema(), +# ) +# expected_coloc_features_df = spark.createDataFrame( +# [ +# (1, "gene1", "eqtlColocClppMaximum", 0.81), +# (1, "gene1", "eqtlColocClppMaximumNeighborhood", -4.0), +# ], +# L2GFeature.get_schema(), +# ) +# # Test +# coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set( +# coloc, +# credset, +# studies, +# ) +# assert coloc_features.df.collect() == expected_coloc_features_df.collect() + + +# class TestStudyLocusFactory: +# """Test the StudyLocusFactory methods.""" + +# def test_get_tss_distance_features( +# self: TestStudyLocusFactory, mock_study_locus: StudyLocus, mock_v2g: V2G +# ) -> None: +# """Test the function that extracts the distance to the TSS.""" +# tss_distance = StudyLocusFactory._get_tss_distance_features( +# mock_study_locus, mock_v2g +# ) +# assert isinstance( +# tss_distance, L2GFeature +# ), "Unexpected model type returned from _get_tss_distance_features" + +# def test_get_vep_features( +# self: TestStudyLocusFactory, mock_study_locus: StudyLocus, mock_v2g: V2G +# ) -> None: +# """Test the function that extracts the VEP features.""" +# vep_features = StudyLocusFactory._get_vep_features(mock_study_locus, mock_v2g) +# assert isinstance( +# vep_features, L2GFeature +# ), "Unexpected model type returned from _get_vep_features" From d06c05945fab4e67a9f9b4ae99dd814e05733f25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 10 Sep 2024 09:06:34 +0200 Subject: [PATCH 
13/48] chore(L2GFeatureMatrix): change `mode` attribute to `with_gold_standard`

---
 src/gentropy/dataset/l2g_feature_matrix.py | 20 +++++++-------------
 src/gentropy/dataset/l2g_prediction.py     |  4 ++--
 src/gentropy/l2g.py                        |  4 ++--
 tests/gentropy/conftest.py                 |  2 +-
 tests/gentropy/dataset/test_l2g.py         |  2 +-
 5 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py
index 375be5727..4b4b861a5 100644
--- a/src/gentropy/dataset/l2g_feature_matrix.py
+++ b/src/gentropy/dataset/l2g_feature_matrix.py
@@ -22,23 +22,17 @@ def __init__(
         self,
         _df: DataFrame,
         features_list: list[str] | None = None,
-        mode: str = "predict",
+        with_gold_standard: bool = False,
     ) -> None:
         """Post-initialisation to set the features list. If not provided, all columns except the fixed ones are used.
 
         Args:
             _df (DataFrame): Feature matrix dataset
             features_list (list[str] | None): List of features to use. If None, all possible features are used.
-            mode (str): Mode of the feature matrix. Defaults to "train". Can be either "train" or "predict".
-
-        Raises:
-            ValueError: If the mode is neither 'train' nor 'predict'.
+            with_gold_standard (bool): Whether to include the gold standard set in the feature matrix.
         """
-        if mode not in ["train", "predict"]:
-            raise ValueError("Mode should be either 'train' or 'predict'")
-
         self.fixed_cols = ["studyLocusId", "geneId"]
-        if mode == "train":
+        if with_gold_standard:
             self.fixed_cols.append("goldStandardSet")
 
         self.features_list = features_list or [
@@ -56,10 +50,10 @@ def __init__(
     def from_features_list(
         cls: Type[L2GFeatureMatrix],
         session: Session,
-        credible_set: StudyLocus,
+        credible_set: StudyLocus,  # TODO: union this with gold standard
         features_list: list[str],
         features_input_loader: L2GFeatureInputLoader,
-        mode: str,
+        with_gold_standard: bool,
     ) -> L2GFeatureMatrix:
         """Generate features from the gentropy datasets by calling the feature factory that will instantiate the corresponding features.
 
         Args:
             session (Session): Session object
             credible_set (StudyLocus): Credible set of study locus pairs to annotate
             features_list (list[str]): List of feature names to be computed.
             features_input_loader (L2GFeatureInputLoader): Object that contains features input.
-            mode (str): Mode of the feature matrix. Can be either "train" or "predict". If "train", the column with the gold standard set will be added to the feature matrix.
+            with_gold_standard (bool): Whether to include the gold standard set in the feature matrix.
Returns: L2GFeatureMatrix: L2G feature matrix dataset @@ -90,7 +84,7 @@ def from_features_list( "featureName", "featureValue", ), - mode=mode, + with_gold_standard=with_gold_standard, ) def calculate_feature_missingness_rate( diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index a505b8a04..ad23b53cf 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -75,7 +75,7 @@ def from_credible_set( credible_set=credible_set, features_list=features_list, features_input_loader=features_input_loader, - mode="predict", + with_gold_standard=False, ).fill_na() gwas_fm = L2GFeatureMatrix( @@ -87,7 +87,7 @@ def from_credible_set( on="studyLocusId", ) ), - mode="predict", + with_gold_standard=False, ).select_features(features_list) return ( l2g_model.predict(gwas_fm, session), diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index edc5da149..4dab48ed3 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -239,7 +239,7 @@ def _generate_feature_matrix(self) -> L2GFeatureMatrix: self.credible_set, self.features_list, self.features_input_loader, - mode="predict", # here we don't have the goldStandardSet col + with_gold_standard=False, ) return ( @@ -251,7 +251,7 @@ def _generate_feature_matrix(self) -> L2GFeatureMatrix: on=["studyLocusId", "geneId"], how="inner", ), - mode="train", # goldStandardSet col is there after joining with the GS + with_gold_standard=True, # goldStandardSet col is there after joining with the GS ) .fill_na() .select_features(self.features_list) diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py index 1dd43a8fd..93ee38471 100644 --- a/tests/gentropy/conftest.py +++ b/tests/gentropy/conftest.py @@ -592,7 +592,7 @@ def mock_l2g_feature_matrix(spark: SparkSession) -> L2GFeatureMatrix: ], "studyLocusId LONG, geneId STRING, distanceTssMean FLOAT, distanceTssMinimum FLOAT", ), - mode="predict", + with_gold_standard=False, ) diff --git a/tests/gentropy/dataset/test_l2g.py b/tests/gentropy/dataset/test_l2g.py index 496398945..d37ce5a4a 100644 --- a/tests/gentropy/dataset/test_l2g.py +++ b/tests/gentropy/dataset/test_l2g.py @@ -166,7 +166,7 @@ def test_l2g_feature_constructor_with_schema_mismatch( ], "studyLocusId LONG, geneId STRING, distanceTssMean DOUBLE", ), - mode="predict", + with_gold_standard=False, ) assert ( fm._df.schema["distanceTssMean"].dataType == FloatType() From 0a007a72882b0224872b4324411f561d429c5b4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 10 Sep 2024 09:32:33 +0200 Subject: [PATCH 14/48] refactor(l2g): move feature matrix writing to training module --- config/step/ot_locus_to_gene_predict.yaml | 1 - config/step/ot_locus_to_gene_train.yaml | 2 ++ src/gentropy/config.py | 1 + src/gentropy/dataset/l2g_prediction.py | 10 +++------ src/gentropy/l2g.py | 27 ++++++++++++++--------- 5 files changed, 23 insertions(+), 18 deletions(-) diff --git a/config/step/ot_locus_to_gene_predict.yaml b/config/step/ot_locus_to_gene_predict.yaml index c3cb88b59..97080223a 100644 --- a/config/step/ot_locus_to_gene_predict.yaml +++ b/config/step/ot_locus_to_gene_predict.yaml @@ -4,7 +4,6 @@ defaults: run_mode: predict model_path: null predictions_path: ${datasets.l2g_predictions} -feature_matrix_path: ${datasets.l2g_feature_matrix} credible_set_path: ${datasets.credible_set} variant_gene_path: ${datasets.variant_to_gene} colocalisation_path: ${datasets.colocalisation} diff --git a/config/step/ot_locus_to_gene_train.yaml 
b/config/step/ot_locus_to_gene_train.yaml
index cf6d0a574..e1504fff0 100644
--- a/config/step/ot_locus_to_gene_train.yaml
+++ b/config/step/ot_locus_to_gene_train.yaml
@@ -12,6 +12,8 @@ colocalisation_path: ${datasets.colocalisation}
 study_index_path: ${datasets.study_index}
 gold_standard_curation_path: ${datasets.l2g_gold_standard_curation}
 gene_interactions_path: ${datasets.gene_interactions}
+feature_matrix_path: ${datasets.l2g_feature_matrix}
+write_feature_matrix: true
 hyperparameters:
   n_estimators: 100
   max_depth: 5
diff --git a/src/gentropy/config.py b/src/gentropy/config.py
index f85adf6cc..70151f8f3 100644
--- a/src/gentropy/config.py
+++ b/src/gentropy/config.py
@@ -236,6 +236,7 @@ class LocusToGeneConfig(StepConfig):
     wandb_run_name: str | None = None
     hf_hub_repo_id: str | None = "opentargets/locus_to_gene"
     download_from_hub: bool = True
+    write_feature_matrix: bool = True
     _target_: str = "gentropy.l2g.LocusToGeneStep"
diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py
index ad23b53cf..6ff30bf5e 100644
--- a/src/gentropy/dataset/l2g_prediction.py
+++ b/src/gentropy/dataset/l2g_prediction.py
@@ -46,7 +46,7 @@ def from_credible_set(
         model_path: str | None,
         hf_token: str | None = None,
         download_from_hub: bool = True,
-    ) -> tuple[L2GPrediction, L2GFeatureMatrix]:
+    ) -> L2GPrediction:
         """Extract L2G predictions for a set of credible sets derived from GWAS.
 
         Args:
@@ -59,7 +59,7 @@ def from_credible_set(
             download_from_hub (bool): Whether to download the model from the Hugging Face Hub. Defaults to True.
 
         Returns:
-            tuple[L2GPrediction, L2GFeatureMatrix]: L2G dataset and feature matrix limited to GWAS study only.
+            L2GPrediction: L2G scores for a set of credible sets.
         """
         # Load the model
         if download_from_hub:
@@ -87,9 +87,5 @@ def from_credible_set(
                     on="studyLocusId",
                 )
             ),
-            with_gold_standard=False,
         ).select_features(features_list)
-        return (
-            l2g_model.predict(gwas_fm, session),
-            gwas_fm,
-        )
+        return l2g_model.predict(gwas_fm, session)
diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py
index 4dab48ed3..3c3f09234 100644
--- a/src/gentropy/l2g.py
+++ b/src/gentropy/l2g.py
@@ -44,6 +44,7 @@ def __init__(
         gene_interactions_path: str | None = None,
         predictions_path: str | None = None,
         feature_matrix_path: str | None = None,
+        write_feature_matrix: bool = True,
         hf_hub_repo_id: str | None = LocusToGeneConfig().hf_hub_repo_id,
     ) -> None:
         """Initialise the step and run the logic based on mode.
@@ -64,6 +65,7 @@ def __init__(
         gene_interactions_path (str | None): Path to the gene interactions dataset
         predictions_path (str | None): Path to the L2G predictions output dataset
         feature_matrix_path (str | None): Path to the L2G feature matrix output dataset
+        write_feature_matrix (bool): Whether to write the full feature matrix to the filesystem. Defaults to True.
         hf_hub_repo_id (str | None): Hugging Face Hub repository ID. If provided, the model will be uploaded to Hugging Face.
Raises: @@ -131,9 +133,8 @@ def run_predict(self) -> None: Raises: ValueError: If not all dependencies in prediction mode are set """ - # TODO: IMPROVE - it is not correct that L2GPrediction outputs a feature matrix - FM should be written when training if self.studies and self.v2g and self.coloc: - predictions, feature_matrix = L2GPrediction.from_credible_set( + predictions = L2GPrediction.from_credible_set( self.session, self.credible_set, self.features_list, @@ -142,10 +143,6 @@ def run_predict(self) -> None: hf_token=access_gcp_secret("hfhub-key", "open-targets-genetics-dev"), download_from_hub=self.download_from_hub, ) - if self.feature_matrix_path: - feature_matrix._df.write.mode(self.session.write_mode).parquet( - self.feature_matrix_path - ) if self.predictions_path: predictions.df.write.mode(self.session.write_mode).parquet( self.predictions_path @@ -165,7 +162,7 @@ def run_train(self) -> None: ): wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev") # Process gold standard and L2G features - data = self._generate_feature_matrix() + data = self._generate_feature_matrix(write_feature_matrix=True) # Instantiate classifier and train model l2g_model = LocusToGeneModel( @@ -193,13 +190,17 @@ def run_train(self) -> None: commit_message="chore: update model", ) - def _generate_feature_matrix(self) -> L2GFeatureMatrix: - """Generate the feature matrix for training. + def _generate_feature_matrix(self, write_feature_matrix: bool) -> L2GFeatureMatrix: + """Generate the feature matrix of annotated gold standards. + + Args: + write_feature_matrix (bool): Whether to write the feature matrix for all credible sets to disk Returns: L2GFeatureMatrix: Feature matrix with gold standards annotated with features. Raises: + ValueError: If write_feature_matrix is set to True but a path is not provided. ValueError: If dependencies to build features are not set. 
""" if self.gs_curation and self.interactions and self.v2g and self.studies: @@ -239,8 +240,14 @@ def _generate_feature_matrix(self) -> L2GFeatureMatrix: self.credible_set, self.features_list, self.features_input_loader, - with_gold_standard=False, + False, ) + if write_feature_matrix: + if not self.feature_matrix_path: + raise ValueError("feature_matrix_path must be set.") + fm._df.write.mode(self.session.write_mode).parquet( + self.feature_matrix_path + ) return ( L2GFeatureMatrix( From abfdf220d24efc293e2d8236c0f3f2b0e045e170 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 10 Sep 2024 09:46:28 +0200 Subject: [PATCH 15/48] feat(L2GFeatureMatrix): accept L2GGoldStandard or StudyLocus as inputs --- src/gentropy/dataset/l2g_feature.py | 7 ++++-- src/gentropy/dataset/l2g_feature_matrix.py | 11 +++++----- src/gentropy/dataset/l2g_prediction.py | 3 +-- src/gentropy/l2g.py | 1 - src/gentropy/method/l2g/feature_factory.py | 25 +++++++++++++--------- 5 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py index c46d2cb0f..1572fe542 100644 --- a/src/gentropy/dataset/l2g_feature.py +++ b/src/gentropy/dataset/l2g_feature.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from pyspark.sql.types import StructType + from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_locus import StudyLocus @@ -46,12 +47,14 @@ def get_schema(cls: type[L2GFeature]) -> StructType: @classmethod @abstractmethod def compute( - cls: type[L2GFeature], credible_set: StudyLocus, feature_dependency: Any + cls: type[L2GFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: Any, ) -> L2GFeature: """Computes the L2GFeature dataset. Args: - credible_set (StudyLocus): The credible set that will be used for annotation + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation feature_dependency (Any): The dependency that the L2GFeature class needs to compute the feature Returns: L2GFeature: a L2GFeature dataset diff --git a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py index 4b4b861a5..113e14cf8 100644 --- a/src/gentropy/dataset/l2g_feature_matrix.py +++ b/src/gentropy/dataset/l2g_feature_matrix.py @@ -12,6 +12,7 @@ from pyspark.sql import DataFrame from gentropy.common.session import Session + from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_locus import StudyLocus @@ -50,19 +51,17 @@ def __init__( def from_features_list( cls: Type[L2GFeatureMatrix], session: Session, - credible_set: StudyLocus, # TODO: union this with gold standard + study_loci_to_annotate: StudyLocus | L2GGoldStandard, features_list: list[str], features_input_loader: L2GFeatureInputLoader, - with_gold_standard: bool, ) -> L2GFeatureMatrix: """Generate features from the gentropy datasets by calling the feature factory that will instantiate the corresponding features. Args: session (Session): Session object - credible_set (StudyLocus): Credible set of study locus pairs to annotate + study_loci_to_annotate (StudyLocus | L2GGoldStandard): Study locus pairs to annotate features_list (list[str]): List of feature names to be computed. features_input_loader (L2GFeatureInputLoader): Object that contais features input. - with_gold_standard (bool): Whether to include the gold standard set in the feature matrix. 
Returns: L2GFeatureMatrix: L2G feature matrix dataset @@ -73,7 +72,7 @@ def from_features_list( # Compute all features and merge them into a single dataframe feature.df for feature in FeatureFactory( - credible_set, features_list + study_loci_to_annotate, features_list ).generate_features(session, features_input_loader) ], ) @@ -84,7 +83,7 @@ def from_features_list( "featureName", "featureValue", ), - with_gold_standard=with_gold_standard, + with_gold_standard=isinstance(study_loci_to_annotate, L2GGoldStandard), ) def calculate_feature_missingness_rate( diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 6ff30bf5e..bb3ab3d4a 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -72,10 +72,9 @@ def from_credible_set( # Prepare data fm = L2GFeatureMatrix.from_features_list( session, - credible_set=credible_set, + study_loci_to_annotate=credible_set, features_list=features_list, features_input_loader=features_input_loader, - with_gold_standard=False, ).fill_na() gwas_fm = L2GFeatureMatrix( diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 3c3f09234..a8c703242 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -240,7 +240,6 @@ def _generate_feature_matrix(self, write_feature_matrix: bool) -> L2GFeatureMatr self.credible_set, self.features_list, self.features_input_loader, - False, ) if write_feature_matrix: if not self.feature_matrix_path: diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 11df4ed2d..94a15b480 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -9,6 +9,7 @@ from gentropy.common.session import Session from gentropy.common.spark_helpers import convert_from_wide_to_long from gentropy.dataset.l2g_feature import L2GFeature +from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_locus import StudyLocus from gentropy.dataset.v2g import V2G @@ -65,14 +66,14 @@ class DistanceTssMinimumFeature(L2GFeature): @classmethod def compute( cls: type[DistanceTssMinimumFeature], - credible_set: StudyLocus, + study_loci_to_annotate: StudyLocus | L2GGoldStandard, feature_dependency: V2G, ) -> L2GFeature: """Computes the feature. Args: - credible_set (StudyLocus): Credible set dependency - feature_dependency (V2G): V2G dependency + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (V2G): Dataset that contains the distance information Returns: L2GFeature: Feature dataset @@ -92,13 +93,13 @@ class DistanceTssMeanFeature(L2GFeature): @classmethod def compute( cls: type[DistanceTssMeanFeature], - credible_set: StudyLocus, + study_loci_to_annotate: StudyLocus | L2GGoldStandard, feature_dependency: V2G, ) -> DistanceTssMeanFeature: """Computes the feature. 
        Args:
-            credible_set (StudyLocus): Credible set dependency
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
             feature_dependency (V2G): Dataset that contains the distance information
 
         Returns:
@@ -108,7 +109,9 @@ def compute(
         # Everything but expression is common logic
         v2g = feature_dependency.df.filter(f.col("datasourceId") == "canonical_tss")
         wide_df = (
-            credible_set.df.withColumn("variantInLocus", f.explode_outer("locus"))
+            study_loci_to_annotate.df.withColumn(
+                "variantInLocus", f.explode_outer("locus")
+            )
             .select(
                 "studyLocusId",
                 f.col("variantInLocus.variantId").alias("variantInLocusId"),
@@ -148,15 +151,17 @@ class FeatureFactory:
     }
 
     def __init__(
-        self: FeatureFactory, credible_set: StudyLocus, features_list: list[str]
+        self: FeatureFactory,
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        features_list: list[str],
     ) -> None:
         """Initializes the factory.
 
         Args:
-            credible_set (StudyLocus): credible sets to annotate
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
             features_list (list[str]): list of features to compute.
         """
-        self.credible_set = credible_set
+        self.study_loci_to_annotate = study_loci_to_annotate
         self.features_list = features_list
 
     def generate_features(
@@ -204,6 +209,6 @@ def compute_feature(
         feature_cls = self.feature_mapper[feature_name]
         feature_input_type = feature_cls.feature_dependency
         return feature_cls.compute(
-            credible_set=self.credible_set,
+            study_loci_to_annotate=self.study_loci_to_annotate,
             feature_dependency=features_input_loader.get_dependency(feature_input_type),
         )

From 1eed6f3fbdf9d1b754d3d800fe3d37d3c98e7c9f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Tue, 10 Sep 2024 10:39:13 +0200
Subject: [PATCH 16/48] feat: implement methods to build a feature matrix
 based on a studylocus/L2GGoldStandard instance

---
 src/gentropy/dataset/l2g_feature_matrix.py |  5 +----
 src/gentropy/dataset/l2g_gold_standard.py  | 25 ++++++++++++++++++++++
 src/gentropy/dataset/l2g_prediction.py     |  1 -
 src/gentropy/dataset/study_locus.py        | 24 +++++++++++++++++++++
 src/gentropy/l2g.py                        | 19 +++------------
 src/gentropy/method/l2g/feature_factory.py |  3 ---
 6 files changed, 53 insertions(+), 24 deletions(-)

diff --git a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py
index 113e14cf8..b64961ca5 100644
--- a/src/gentropy/dataset/l2g_feature_matrix.py
+++ b/src/gentropy/dataset/l2g_feature_matrix.py
@@ -11,7 +11,6 @@
 if TYPE_CHECKING:
     from pyspark.sql import DataFrame
 
-    from gentropy.common.session import Session
     from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
     from gentropy.dataset.study_locus import StudyLocus
 
@@ -50,7 +49,6 @@ def from_features_list(
     @classmethod
     def from_features_list(
         cls: Type[L2GFeatureMatrix],
-        session: Session,
         study_loci_to_annotate: StudyLocus | L2GGoldStandard,
         features_list: list[str],
         features_input_loader: L2GFeatureInputLoader,
     ) -> L2GFeatureMatrix:
         """Generate features from the gentropy datasets by calling the feature factory that will instantiate the corresponding features.
 
         Args:
-            session (Session): Session object
             study_loci_to_annotate (StudyLocus | L2GGoldStandard): Study locus pairs to annotate
             features_list (list[str]): List of feature names to be computed.
             features_input_loader (L2GFeatureInputLoader): Object that contains features input.
@@ -73,7 +70,7 @@ def from_features_list( feature.df for feature in FeatureFactory( study_loci_to_annotate, features_list - ).generate_features(session, features_input_loader) + ).generate_features(features_input_loader) ], ) return cls( diff --git a/src/gentropy/dataset/l2g_gold_standard.py b/src/gentropy/dataset/l2g_gold_standard.py index 5bc48413c..89f4c5f5d 100644 --- a/src/gentropy/dataset/l2g_gold_standard.py +++ b/src/gentropy/dataset/l2g_gold_standard.py @@ -1,4 +1,5 @@ """L2G gold standard dataset.""" + from __future__ import annotations from dataclasses import dataclass @@ -15,6 +16,7 @@ from pyspark.sql import DataFrame from pyspark.sql.types import StructType + from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.study_locus_overlap import StudyLocusOverlap from gentropy.dataset.v2g import V2G @@ -100,6 +102,29 @@ def process_gene_interactions( "scoring as score", ) + def build_feature_matrix( + self: L2GGoldStandard, + full_feature_matrix: L2GFeatureMatrix, + ) -> L2GFeatureMatrix: + """Return a feature matrix for study loci in the gold standard. + + Args: + full_feature_matrix (L2GFeatureMatrix): Feature matrix for all study loci to join on + + Returns: + L2GFeatureMatrix: Feature matrix for study loci in the gold standard + """ + from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix + + return L2GFeatureMatrix( + _df=full_feature_matrix._df.join( + f.broadcast(self.df.drop("variantId", "studyId", "sources")), + on=["studyLocusId", "geneId"], + how="inner", + ), + with_gold_standard=True, + ) + def filter_unique_associations( self: L2GGoldStandard, study_locus_overlap: StudyLocusOverlap, diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index bb3ab3d4a..86210097c 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -71,7 +71,6 @@ def from_credible_set( # Prepare data fm = L2GFeatureMatrix.from_features_list( - session, study_loci_to_annotate=credible_set, features_list=features_list, features_input_loader=features_input_loader, diff --git a/src/gentropy/dataset/study_locus.py b/src/gentropy/dataset/study_locus.py index ee488b019..3bf2aff70 100644 --- a/src/gentropy/dataset/study_locus.py +++ b/src/gentropy/dataset/study_locus.py @@ -24,9 +24,11 @@ from pyspark.sql import Column, DataFrame from pyspark.sql.types import StructType + from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.ld_index import LDIndex from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.summary_statistics import SummaryStatistics + from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader class StudyLocusQualityCheck(Enum): @@ -572,6 +574,28 @@ def neglog_pvalue(self: StudyLocus) -> Column: self.df.pValueExponent, ) + def build_feature_matrix( + self: StudyLocus, + features_list: list[str], + features_input_loader: L2GFeatureInputLoader, + ) -> L2GFeatureMatrix: + """Returns the feature matrix for a StudyLocus. + + Args: + features_list (list[str]): List of features to include in the feature matrix. + features_input_loader (L2GFeatureInputLoader): Feature input loader to use. + + Returns: + L2GFeatureMatrix: Feature matrix for this study-locus. 
+ """ + from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix + + return L2GFeatureMatrix.from_features_list( + self, + features_list, + features_input_loader, + ) + def annotate_credible_sets(self: StudyLocus) -> StudyLocus: """Annotate study-locus dataset with credible set flags. diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index a8c703242..2914443a7 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -234,12 +234,8 @@ def _generate_feature_matrix(self, write_feature_matrix: bool) -> L2GFeatureMatr interactions=self.interactions, ) - # TODO: Should StudyLocus and GoldStandard have an `annotate_w_features` method? Yes - fm = L2GFeatureMatrix.from_features_list( - self.session, - self.credible_set, - self.features_list, - self.features_input_loader, + fm = self.credible_set.build_feature_matrix( + self.features_list, self.features_input_loader ) if write_feature_matrix: if not self.feature_matrix_path: @@ -249,16 +245,7 @@ def _generate_feature_matrix(self, write_feature_matrix: bool) -> L2GFeatureMatr ) return ( - L2GFeatureMatrix( - _df=fm._df.join( - f.broadcast( - gold_standards.df.drop("variantId", "studyId", "sources") - ), - on=["studyLocusId", "geneId"], - how="inner", - ), - with_gold_standard=True, # goldStandardSet col is there after joining with the GS - ) + gold_standards.build_feature_matrix(fm) .fill_na() .select_features(self.features_list) ) diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 94a15b480..3c52dd946 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -6,7 +6,6 @@ import pyspark.sql.functions as f -from gentropy.common.session import Session from gentropy.common.spark_helpers import convert_from_wide_to_long from gentropy.dataset.l2g_feature import L2GFeature from gentropy.dataset.l2g_gold_standard import L2GGoldStandard @@ -166,13 +165,11 @@ def __init__( def generate_features( self: FeatureFactory, - session: Session, features_input_loader: L2GFeatureInputLoader, ) -> list[L2GFeature]: """Generates a feature matrix by reading an object with instructions on how to create the features. Args: - session (Session): session object features_input_loader (L2GFeatureInputLoader): object with required features dependencies. Returns: From b4a86a14fa1d3dd5a6afa23692a0e53042fb7853 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 10 Sep 2024 19:37:55 +0200 Subject: [PATCH 17/48] feat: coloc logic prototype --- src/gentropy/dataset/l2g_feature.py | 6 +- src/gentropy/method/l2g/feature_factory.py | 149 ++++++++++++++++++++- 2 files changed, 148 insertions(+), 7 deletions(-) diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py index 1572fe542..765bb5d55 100644 --- a/src/gentropy/dataset/l2g_feature.py +++ b/src/gentropy/dataset/l2g_feature.py @@ -22,17 +22,17 @@ class L2GFeature(Dataset, ABC): def __post_init__( self: L2GFeature, - feature_dependency: Any = None, + feature_dependency_type: Any = None, credible_set: StudyLocus | None = None, ) -> None: """Initializes a L2GFeature dataset. Args: - feature_dependency (Any): The dependency that the L2GFeature dataset depends on. Defaults to None. + feature_dependency_type (Any): The dependency that the L2GFeature dataset depends on. Defaults to None. credible_set (StudyLocus | None): The credible set that the L2GFeature dataset is based on. Defaults to None. 
""" super().__post_init__() - self.feature_dependency = feature_dependency + self.feature_dependency_type = feature_dependency_type self.credible_set = credible_set @classmethod diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 3c52dd946..df72fc4ef 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -6,11 +6,16 @@ import pyspark.sql.functions as f -from gentropy.common.spark_helpers import convert_from_wide_to_long +from gentropy.common.spark_helpers import ( + convert_from_wide_to_long, + get_record_with_maximum_value, +) +from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.l2g_feature import L2GFeature from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_locus import StudyLocus from gentropy.dataset.v2g import V2G +from gentropy.method.colocalisation import ECaviar class L2GFeatureInputLoader: @@ -59,6 +64,140 @@ def __repr__(self) -> str: return repr(self.input_dependencies) +class EqtlColocClppMaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all eQTLs.""" + + feature_dependency_type = Colocalisation + feature_name = "eQtlColocClppMaximum" + + @classmethod + def compute( + cls: type[EqtlColocClppMaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: Colocalisation, + ) -> EqtlColocClppMaximumFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (Colocalisation): Dataset with the colocalisation results + + Returns: + EqtlColocClppMaximumFeature: Feature dataset + """ + qtl_type = "eqtl" + + ecaviar_results = feature_dependency.filter( + f.col("colocalisationMethod") == ECaviar.METHOD_NAME + ) + ecaviar_metric = ECaviar.METHOD_METRIC + + # From here all code is common + qtl_specific_study_loci = study_loci_to_annotate.filter( + f.col("studyType") == qtl_type + ) + colocalising_study_loci = ( + ecaviar_results.df.join( + f.broadcast(study_loci_to_annotate.df.select("studyLocusId")), + on="studyLocusId", + ) + # filter out gwas loci on the right side + .join( + f.broadcast( + qtl_specific_study_loci.df.selectExpr( + "studyLocusId as rightStudyLocusId" + ) + ), + on="rightStudyLocusId", + ) + ) + agg_expr = get_record_with_maximum_value( + colocalising_study_loci, + ["studyLocusId", "geneId"], + ecaviar_metric, + ).selectExpr( + "studyLocusId", + "geneId", + f"{ecaviar_metric} as {cls.feature_name}", + ) + return cls( + _df=convert_from_wide_to_long( + agg_expr, + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class PqtlColocClppMaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" + + feature_dependency_type = Colocalisation + feature_name = "pQtlColocClppMaximum" + + @classmethod + def compute( + cls: type[PqtlColocClppMaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: Colocalisation, + ) -> PqtlColocClppMaximumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (Colocalisation): Dataset with the colocalisation results + + Returns: + PqtlColocClppMaximumFeature: Feature dataset + """ + qtl_type = "pqtl" + + ecaviar_results = feature_dependency.filter( + f.col("colocalisationMethod") == ECaviar.METHOD_NAME + ) + ecaviar_metric = ECaviar.METHOD_METRIC + + # From here all code is common + qtl_specific_study_loci = study_loci_to_annotate.filter( + f.col("studyType") == qtl_type + ) + colocalising_study_loci = ( + ecaviar_results.df.join( + f.broadcast(study_loci_to_annotate.df.select("studyLocusId")), + on="studyLocusId", + ) + # filter out gwas loci on the right side + .join( + f.broadcast( + qtl_specific_study_loci.df.selectExpr( + "studyLocusId as rightStudyLocusId" + ) + ), + on="rightStudyLocusId", + ) + ) + agg_expr = get_record_with_maximum_value( + colocalising_study_loci, + ["studyLocusId", "geneId"], + ecaviar_metric, + ).selectExpr( + "studyLocusId", + "geneId", + f"{ecaviar_metric} as {cls.feature_name}", + ) + return cls( + _df=convert_from_wide_to_long( + agg_expr, + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + class DistanceTssMinimumFeature(L2GFeature): """Minimum distance of all tagging variants to gene TSS.""" @@ -87,7 +226,7 @@ class DistanceTssMeanFeature(L2GFeature): """Average distance of all tagging variants to gene TSS.""" fill_na_value = 500_000 - feature_dependency = V2G + feature_dependency_type = V2G @classmethod def compute( @@ -204,8 +343,10 @@ def compute_feature( """ # Extract feature class and dependency type feature_cls = self.feature_mapper[feature_name] - feature_input_type = feature_cls.feature_dependency + feature_dependency_type = feature_cls.feature_dependency_type return feature_cls.compute( study_loci_to_annotate=self.study_loci_to_annotate, - feature_dependency=features_input_loader.get_dependency(feature_input_type), + feature_dependency=features_input_loader.get_dependency( + feature_dependency_type + ), ) From 0b09193ad5a14c18ddb587cd4a63b46f077558c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 11 Sep 2024 15:28:04 +0100 Subject: [PATCH 18/48] feat(l2g): filter non gwas credible sets at the start of the step --- src/gentropy/dataset/l2g_prediction.py | 29 ++++++++++---------------- src/gentropy/l2g.py | 2 +- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/src/gentropy/dataset/l2g_prediction.py b/src/gentropy/dataset/l2g_prediction.py index 86210097c..97e58f526 100644 --- a/src/gentropy/dataset/l2g_prediction.py +++ b/src/gentropy/dataset/l2g_prediction.py @@ -9,7 +9,6 @@ from gentropy.common.session import Session from gentropy.dataset.dataset import Dataset from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix -from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader from gentropy.method.l2g.model import LocusToGeneModel @@ -51,7 +50,7 @@ def from_credible_set( Args: session (Session): Session object that contains the Spark session - credible_set (StudyLocus): Credible set dataset + credible_set (StudyLocus): Dataset containing credible sets from GWAS only features_list (list[str]): List of features to use for the model features_input_loader (L2GFeatureInputLoader): Loader with all feature 
dependencies model_path (str | None): Path to the model file. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name). @@ -70,20 +69,14 @@ def from_credible_set( l2g_model = LocusToGeneModel.load_from_disk(model_path) # Prepare data - fm = L2GFeatureMatrix.from_features_list( - study_loci_to_annotate=credible_set, - features_list=features_list, - features_input_loader=features_input_loader, - ).fill_na() + fm = ( + L2GFeatureMatrix.from_features_list( + study_loci_to_annotate=credible_set, + features_list=features_list, + features_input_loader=features_input_loader, + ) + .fill_na() + .select_features(features_list) + ) - gwas_fm = L2GFeatureMatrix( - _df=( - fm._df.join( - credible_set.filter_by_study_type( - "gwas", features_input_loader.get_dependency(StudyIndex) - ).df.select("studyLocusId"), - on="studyLocusId", - ) - ), - ).select_features(features_list) - return l2g_model.predict(gwas_fm, session) + return l2g_model.predict(fm, session) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index 2914443a7..7ce77eaa1 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -90,7 +90,7 @@ def __init__( # Load common inputs self.credible_set = StudyLocus.from_parquet( session, credible_set_path, recursiveFileLookup=True - ) + ).filter(f.col("studyType") == "gwas") self.studies = ( StudyIndex.from_parquet(session, study_index_path, recursiveFileLookup=True) if study_index_path From a60095bbb9ecfc5f1830130c43fd06033cb6c772 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 13 Sep 2024 11:49:55 +0100 Subject: [PATCH 19/48] feat: rewrite colocalisation feature factory --- config/step/ot_locus_to_gene_train.yaml | 14 +- src/gentropy/dataset/colocalisation.py | 65 +++ .../datasource/eqtl_catalogue/study_index.py | 20 +- src/gentropy/method/l2g/feature_factory.py | 373 ++++++++++++++---- 4 files changed, 373 insertions(+), 99 deletions(-) diff --git a/config/step/ot_locus_to_gene_train.yaml b/config/step/ot_locus_to_gene_train.yaml index e1504fff0..181e1303b 100644 --- a/config/step/ot_locus_to_gene_train.yaml +++ b/config/step/ot_locus_to_gene_train.yaml @@ -21,6 +21,16 @@ hyperparameters: download_from_hub: true features_list: # average distance of all tagging variants to gene TSS - - distanceTssMean + # - distanceTssMean # minimum distance of all tagging variants to gene TSS - - distanceTssMinimum + # - distanceTssMinimum + # max CLPP for each (study, locus, gene) aggregating over a specific qtl type + - eQtlColocClppMaximum + - pQtlColocClppMaximum + - sQtlColocClppMaximum + - tuQtlColocClppMaximum + # max H4 for each (study, locus, gene) aggregating over a specific qtl type + - eQtlColocH4Maximum + - pQtlColocH4Maximum + - sQtlColocH4Maximum + - tuQtlColocH4Maximum diff --git a/src/gentropy/dataset/colocalisation.py b/src/gentropy/dataset/colocalisation.py index e72543cb2..171dc4789 100644 --- a/src/gentropy/dataset/colocalisation.py +++ b/src/gentropy/dataset/colocalisation.py @@ -1,15 +1,26 @@ """Colocalisation dataset.""" + from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING +import pyspark.sql.functions as f + from gentropy.common.schemas import parse_spark_schema +from gentropy.common.spark_helpers import get_record_with_maximum_value from gentropy.dataset.dataset import Dataset +from gentropy.datasource.eqtl_catalogue.study_index import EqtlCatalogueStudyIndex if TYPE_CHECKING: + from pyspark.sql import DataFrame from pyspark.sql.types import StructType + 
from gentropy.dataset.l2g_gold_standard import L2GGoldStandard + from gentropy.dataset.study_locus import StudyLocus + +from functools import reduce + @dataclass class Colocalisation(Dataset): @@ -23,3 +34,57 @@ def get_schema(cls: type[Colocalisation]) -> StructType: StructType: Schema for the Colocalisation dataset """ return parse_spark_schema("colocalisation.json") + + def extract_maximum_coloc_probability_per_region_and_gene( + self: Colocalisation, + study_loci: StudyLocus | L2GGoldStandard, + filter_by_colocalisation_method: str, + filter_by_qtl: str | None, + ) -> DataFrame: + """Get maximum colocalisation probability for a (studyLocus, gene) window. + + Args: + study_loci (StudyLocus | L2GGoldStandard): Dataset containing study loci to filter the colocalisation dataset on and the geneId linked to the region + filter_by_colocalisation_method (str): optional filter to apply on the colocalisation dataset + filter_by_qtl (str | None): optional filter to apply on the colocalisation dataset + + Returns: + DataFrame: table with the maximum colocalisation scores for the provided study loci + + Raises: + ValueError: if filter_by_qtl is not in the list of valid QTL types + ValueError: if filter_by_colocalisation_method is not in the list of valid colocalisation methods + """ + from gentropy.colocalisation import ColocalisationStep + + valid_qtls = list(EqtlCatalogueStudyIndex.method_to_study_type_mapping.values()) + if filter_by_qtl and filter_by_qtl not in valid_qtls: + raise ValueError(f"There are no studies with QTL type {filter_by_qtl}") + + if filter_by_colocalisation_method not in [ + "ECaviar", + "COLOC", + ]: # TODO: Write helper class to retrieve coloc method names + raise ValueError( + f"Colocalisation method {filter_by_colocalisation_method} is not supported." 
+ ) + + method_colocalisation_metric = ColocalisationStep._get_colocalisation_class( + filter_by_colocalisation_method + ).METHOD_METRIC # type: ignore + + coloc_filtering_expr = [ + (f.col("rightStudyType") != "gwas"), + f.col("colocalisationMethod") == filter_by_colocalisation_method, + ] + if filter_by_qtl: + coloc_filtering_expr.append(f.col("rightStudyType") == filter_by_qtl) + + return get_record_with_maximum_value( + # Filter coloc dataset based on method and qtl type + self.filter(reduce(lambda a, b: a & b, coloc_filtering_expr)) + # Join with study loci to get geneId + .df.join(study_loci.df.select("studyLocusId", "geneId"), "studyLocusId"), + ["studyLocusId", "geneId"], + method_colocalisation_metric, + ) diff --git a/src/gentropy/datasource/eqtl_catalogue/study_index.py b/src/gentropy/datasource/eqtl_catalogue/study_index.py index 6add70ffb..d284eb781 100644 --- a/src/gentropy/datasource/eqtl_catalogue/study_index.py +++ b/src/gentropy/datasource/eqtl_catalogue/study_index.py @@ -45,6 +45,15 @@ class EqtlCatalogueStudyIndex: ] ) raw_studies_metadata_path = "https://raw.githubusercontent.com/eQTL-Catalogue/eQTL-Catalogue-resources/092e01a9601feb404f1c88f86311b43b907a88f6/data_tables/dataset_metadata_upcoming.tsv" + method_to_study_type_mapping = { + "ge": "eqtl", + "exon": "eqtl", + "tx": "eqtl", + "microarray": "eqtl", + "leafcutter": "sqtl", + "aptamer": "pqtl", + "txrev": "tuqtl", + } @classmethod def _identify_study_type( @@ -76,17 +85,8 @@ def _identify_study_type( +------------+---------+----------+ """ - method_to_study_type_mapping = { - "ge": "eqtl", - "exon": "eqtl", - "tx": "eqtl", - "microarray": "eqtl", - "leafcutter": "sqtl", - "aptamer": "pqtl", - "txrev": "tuqtl", - } qtl_type_mapping = f.create_map( - *[f.lit(x) for x in chain(*method_to_study_type_mapping.items())] + *[f.lit(x) for x in chain(*cls.method_to_study_type_mapping.items())] )[quantification_method_col] return f.when( biosample_col.startswith("CL"), f.concat(f.lit("sc"), qtl_type_mapping) diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index df72fc4ef..afd783667 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -2,20 +2,19 @@ from __future__ import annotations -from typing import Any, Iterator, Mapping +from typing import TYPE_CHECKING, Any, Iterator, Mapping import pyspark.sql.functions as f -from gentropy.common.spark_helpers import ( - convert_from_wide_to_long, - get_record_with_maximum_value, -) +from gentropy.common.spark_helpers import convert_from_wide_to_long from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.l2g_feature import L2GFeature from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_locus import StudyLocus from gentropy.dataset.v2g import V2G -from gentropy.method.colocalisation import ECaviar + +if TYPE_CHECKING: + from pyspark.sql import DataFrame class L2GFeatureInputLoader: @@ -64,7 +63,42 @@ def __repr__(self) -> str: return repr(self.input_dependencies) -class EqtlColocClppMaximumFeature(L2GFeature): +def _common_colocalisation_feature_logic( + feature_dependency: Colocalisation, + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + colocalisation_method: str, + colocalisation_metric: str, + qtl_type: str, +) -> DataFrame: + """Wrapper to call the logic that creates a type of colocalisation features. 
+ + Args: + feature_dependency (Colocalisation): Dataset with the colocalisation results + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + colocalisation_method (str): The colocalisation method to filter the data by + colocalisation_metric (str): The colocalisation metric to use + qtl_type (str): The type of QTL to filter the data by + + Returns: + DataFrame: Feature annotation in long format with the columns: studyLocusId, geneId, featureName, featureValue + """ + return convert_from_wide_to_long( + feature_dependency.extract_maximum_coloc_probability_per_region_and_gene( + study_loci_to_annotate, + filter_by_colocalisation_method=colocalisation_method, + filter_by_qtl=qtl_type, + ).selectExpr( + "studyLocusId", + "geneId", + f"{colocalisation_metric} as cls.featureName", + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ) + + +class EQtlColocClppMaximumFeature(L2GFeature): """Max CLPP for each (study, locus, gene) aggregating over all eQTLs.""" feature_dependency_type = Colocalisation @@ -72,10 +106,10 @@ class EqtlColocClppMaximumFeature(L2GFeature): @classmethod def compute( - cls: type[EqtlColocClppMaximumFeature], + cls: type[EQtlColocClppMaximumFeature], study_loci_to_annotate: StudyLocus | L2GGoldStandard, feature_dependency: Colocalisation, - ) -> EqtlColocClppMaximumFeature: + ) -> EQtlColocClppMaximumFeature: """Computes the feature. Args: @@ -83,66 +117,179 @@ def compute( feature_dependency (Colocalisation): Dataset with the colocalisation results Returns: - EqtlColocClppMaximumFeature: Feature dataset + EQtlColocClppMaximumFeature: Feature dataset """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" qtl_type = "eqtl" - - ecaviar_results = feature_dependency.filter( - f.col("colocalisationMethod") == ECaviar.METHOD_NAME + return cls( + _df=_common_colocalisation_feature_logic( + feature_dependency, + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + qtl_type, + ), + _schema=cls.get_schema(), ) - ecaviar_metric = ECaviar.METHOD_METRIC - # From here all code is common - qtl_specific_study_loci = study_loci_to_annotate.filter( - f.col("studyType") == qtl_type + +class PQtlColocClppMaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" + + feature_dependency_type = Colocalisation + feature_name = "pQtlColocClppMaximum" + + @classmethod + def compute( + cls: type[PQtlColocClppMaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: Colocalisation, + ) -> PQtlColocClppMaximumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (Colocalisation): Dataset with the colocalisation results + + Returns: + PQtlColocClppMaximumFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "pqtl" + return cls( + _df=_common_colocalisation_feature_logic( + feature_dependency, + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + qtl_type, + ), + _schema=cls.get_schema(), ) - colocalising_study_loci = ( - ecaviar_results.df.join( - f.broadcast(study_loci_to_annotate.df.select("studyLocusId")), - on="studyLocusId", - ) - # filter out gwas loci on the right side - .join( - f.broadcast( - qtl_specific_study_loci.df.selectExpr( - "studyLocusId as rightStudyLocusId" - ) - ), - on="rightStudyLocusId", - ) + + +class SQtlColocClppMaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all sQTLs.""" + + feature_dependency_type = Colocalisation + feature_name = "sQtlColocClppMaximum" + + @classmethod + def compute( + cls: type[SQtlColocClppMaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: Colocalisation, + ) -> SQtlColocClppMaximumFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (Colocalisation): Dataset with the colocalisation results + + Returns: + SQtlColocClppMaximumFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "sqtl" + return cls( + _df=_common_colocalisation_feature_logic( + feature_dependency, + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + qtl_type, + ), + _schema=cls.get_schema(), ) - agg_expr = get_record_with_maximum_value( - colocalising_study_loci, - ["studyLocusId", "geneId"], - ecaviar_metric, - ).selectExpr( - "studyLocusId", - "geneId", - f"{ecaviar_metric} as {cls.feature_name}", + + +class TuQtlColocClppMaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all tuQTLs.""" + + feature_dependency_type = Colocalisation + feature_name = "tuQtlColocClppMaximum" + + @classmethod + def compute( + cls: type[TuQtlColocClppMaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: Colocalisation, + ) -> TuQtlColocClppMaximumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (Colocalisation): Dataset with the colocalisation results + + Returns: + TuQtlColocClppMaximumFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "tuqtl" + return cls( + _df=_common_colocalisation_feature_logic( + feature_dependency, + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + qtl_type, + ), + _schema=cls.get_schema(), ) + + +class EQtlColocH4MaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all eQTLs.""" + + feature_dependency_type = Colocalisation + feature_name = "eQtlColocH4Maximum" + + @classmethod + def compute( + cls: type[EQtlColocH4MaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: Colocalisation, + ) -> EQtlColocH4MaximumFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (Colocalisation): Dataset with the colocalisation results + + Returns: + EQtlColocH4MaximumFeature: Feature dataset + """ + colocalisation_method = "COLOC" + colocalisation_metric = "h4" + qtl_type = "eqtl" return cls( - _df=convert_from_wide_to_long( - agg_expr, - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", + _df=_common_colocalisation_feature_logic( + feature_dependency, + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + qtl_type, ), _schema=cls.get_schema(), ) -class PqtlColocClppMaximumFeature(L2GFeature): +class PQtlColocH4MaximumFeature(L2GFeature): """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" feature_dependency_type = Colocalisation - feature_name = "pQtlColocClppMaximum" + feature_name = "pQtlColocH4Maximum" @classmethod def compute( - cls: type[PqtlColocClppMaximumFeature], + cls: type[PQtlColocH4MaximumFeature], study_loci_to_annotate: StudyLocus | L2GGoldStandard, feature_dependency: Colocalisation, - ) -> PqtlColocClppMaximumFeature: + ) -> PQtlColocH4MaximumFeature: """Computes the feature. 
Args: @@ -150,49 +297,90 @@ def compute( feature_dependency (Colocalisation): Dataset with the colocalisation results Returns: - PqtlColocClppMaximumFeature: Feature dataset + PQtlColocH4MaximumFeature: Feature dataset """ + colocalisation_method = "COLOC" + colocalisation_metric = "h4" qtl_type = "pqtl" - - ecaviar_results = feature_dependency.filter( - f.col("colocalisationMethod") == ECaviar.METHOD_NAME + return cls( + _df=_common_colocalisation_feature_logic( + feature_dependency, + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + qtl_type, + ), + _schema=cls.get_schema(), ) - ecaviar_metric = ECaviar.METHOD_METRIC - # From here all code is common - qtl_specific_study_loci = study_loci_to_annotate.filter( - f.col("studyType") == qtl_type - ) - colocalising_study_loci = ( - ecaviar_results.df.join( - f.broadcast(study_loci_to_annotate.df.select("studyLocusId")), - on="studyLocusId", - ) - # filter out gwas loci on the right side - .join( - f.broadcast( - qtl_specific_study_loci.df.selectExpr( - "studyLocusId as rightStudyLocusId" - ) - ), - on="rightStudyLocusId", - ) - ) - agg_expr = get_record_with_maximum_value( - colocalising_study_loci, - ["studyLocusId", "geneId"], - ecaviar_metric, - ).selectExpr( - "studyLocusId", - "geneId", - f"{ecaviar_metric} as {cls.feature_name}", + +class SQtlColocH4MaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all sQTLs.""" + + feature_dependency_type = Colocalisation + feature_name = "sQtlColocH4Maximum" + + @classmethod + def compute( + cls: type[SQtlColocH4MaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: Colocalisation, + ) -> SQtlColocH4MaximumFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (Colocalisation): Dataset with the colocalisation results + + Returns: + SQtlColocH4MaximumFeature: Feature dataset + """ + colocalisation_method = "COLOC" + colocalisation_metric = "h4" + qtl_type = "sqtl" + return cls( + _df=_common_colocalisation_feature_logic( + feature_dependency, + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + qtl_type, + ), + _schema=cls.get_schema(), ) + + +class TuQtlColocH4MaximumFeature(L2GFeature): + """Max H4 for each (study, locus, gene) aggregating over all tuQTLs.""" + + feature_dependency_type = Colocalisation + feature_name = "tuQtlColocH4Maximum" + + @classmethod + def compute( + cls: type[TuQtlColocH4MaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: Colocalisation, + ) -> TuQtlColocH4MaximumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (Colocalisation): Dataset with the colocalisation results + + Returns: + TuQtlColocH4MaximumFeature: Feature dataset + """ + colocalisation_method = "COLOC" + colocalisation_metric = "h4" + qtl_type = "tuqtl" return cls( - _df=convert_from_wide_to_long( - agg_expr, - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", + _df=_common_colocalisation_feature_logic( + feature_dependency, + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + qtl_type, ), _schema=cls.get_schema(), ) @@ -223,7 +411,10 @@ def compute( class DistanceTssMeanFeature(L2GFeature): - """Average distance of all tagging variants to gene TSS.""" + """Average distance of all tagging variants to gene TSS. + + NOTE: to be rewritten taking variant index as input + """ fill_na_value = 500_000 feature_dependency_type = V2G @@ -285,7 +476,15 @@ class FeatureFactory: feature_mapper: Mapping[str, type[L2GFeature]] = { # "distanceTssMinimum": DistanceTssMinimumFeature, - "distanceTssMean": DistanceTssMeanFeature, + # "distanceTssMean": DistanceTssMeanFeature, + "eqtlColocClppMaximum": EQtlColocClppMaximumFeature, + "pqtlColocClppMaximum": PQtlColocClppMaximumFeature, + "sqtlColocClppMaximum": SQtlColocClppMaximumFeature, + "tuqtlColocClppMaximum": TuQtlColocClppMaximumFeature, + "eqtlColocH4Maximum": EQtlColocH4MaximumFeature, + "pqtlColocH4Maximum": PQtlColocH4MaximumFeature, + "sqtlColocH4Maximum": SQtlColocH4MaximumFeature, + "tuqtlColocH4Maximum": TuQtlColocH4MaximumFeature, } def __init__( From 16085ad27d10e0ec158afd3ef53a4bf7e57f75ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 13 Sep 2024 12:10:57 +0100 Subject: [PATCH 20/48] test: add `test_colocalisation_feature_type` --- src/gentropy/dataset/colocalisation.py | 2 +- src/gentropy/method/l2g/feature_factory.py | 8 +-- .../test_feature_factory.py} | 55 +++++++++++++++---- 3 files changed, 50 insertions(+), 15 deletions(-) rename tests/gentropy/method/{test_locus_to_gene.py => test_l2g/test_feature_factory.py} (78%) diff --git a/src/gentropy/dataset/colocalisation.py b/src/gentropy/dataset/colocalisation.py index 171dc4789..226aaca99 100644 --- a/src/gentropy/dataset/colocalisation.py +++ b/src/gentropy/dataset/colocalisation.py @@ -63,7 +63,7 @@ def extract_maximum_coloc_probability_per_region_and_gene( if filter_by_colocalisation_method not in [ "ECaviar", - "COLOC", + "Coloc", ]: # TODO: Write helper class to retrieve coloc method names raise ValueError( f"Colocalisation method {filter_by_colocalisation_method} is not supported." 
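
At this point in the series the colocalisation features, the factory dispatch and the input loader compose end to end. Below is a minimal sketch of that flow, not part of the patch itself: the `credible_set` (StudyLocus) and `coloc` (Colocalisation) variables are assumed to be loaded elsewhere, and the feature names follow the `feature_mapper` keys introduced above.

    from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix
    from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader

    # The loader matches dependencies by type, so the keyword name is free-form.
    loader = L2GFeatureInputLoader(colocalisation=coloc)

    # FeatureFactory looks each name up in feature_mapper, fetches the dependency
    # declared by the feature's feature_dependency_type, calls its compute(), and
    # the long-format outputs are pivoted into one wide feature matrix.
    fm = L2GFeatureMatrix.from_features_list(
        study_loci_to_annotate=credible_set,
        features_list=["eqtlColocClppMaximum", "eqtlColocH4Maximum"],
        features_input_loader=loader,
    )

Because dependencies are resolved with isinstance checks rather than by keyword, a feature class only has to declare the dataset type it needs; the step wiring the loader does not change when new features are added.
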
diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index afd783667..67b8dd71d 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -263,7 +263,7 @@ def compute( Returns: EQtlColocH4MaximumFeature: Feature dataset """ - colocalisation_method = "COLOC" + colocalisation_method = "Coloc" colocalisation_metric = "h4" qtl_type = "eqtl" return cls( @@ -299,7 +299,7 @@ def compute( Returns: PQtlColocH4MaximumFeature: Feature dataset """ - colocalisation_method = "COLOC" + colocalisation_method = "Coloc" colocalisation_metric = "h4" qtl_type = "pqtl" return cls( @@ -335,7 +335,7 @@ def compute( Returns: SQtlColocH4MaximumFeature: Feature dataset """ - colocalisation_method = "COLOC" + colocalisation_method = "Coloc" colocalisation_metric = "h4" qtl_type = "sqtl" return cls( @@ -371,7 +371,7 @@ def compute( Returns: TuQtlColocH4MaximumFeature: Feature dataset """ - colocalisation_method = "COLOC" + colocalisation_method = "Coloc" colocalisation_metric = "h4" qtl_type = "tuqtl" return cls( diff --git a/tests/gentropy/method/test_locus_to_gene.py b/tests/gentropy/method/test_l2g/test_feature_factory.py similarity index 78% rename from tests/gentropy/method/test_locus_to_gene.py rename to tests/gentropy/method/test_l2g/test_feature_factory.py index 7698e99b0..8aff3dc6c 100644 --- a/tests/gentropy/method/test_locus_to_gene.py +++ b/tests/gentropy/method/test_l2g/test_feature_factory.py @@ -2,21 +2,56 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest -from sklearn.ensemble import RandomForestClassifier -from gentropy.method.l2g.model import LocusToGeneModel +from gentropy.dataset.l2g_feature import L2GFeature +from gentropy.method.l2g.feature_factory import ( + EQtlColocClppMaximumFeature, + EQtlColocH4MaximumFeature, + PQtlColocClppMaximumFeature, + PQtlColocH4MaximumFeature, + SQtlColocClppMaximumFeature, + SQtlColocH4MaximumFeature, + TuQtlColocClppMaximumFeature, + TuQtlColocH4MaximumFeature, +) if TYPE_CHECKING: - pass - - -@pytest.fixture(scope="module") -def model() -> LocusToGeneModel: - """Creates an instance of the LocusToGene class.""" - return LocusToGeneModel(model=RandomForestClassifier()) + from gentropy.dataset.colocalisation import Colocalisation + from gentropy.dataset.study_locus import StudyLocus + + +# @pytest.fixture(scope="module") +# def model() -> LocusToGeneModel: +# """Creates an instance of the LocusToGene class.""" +# return LocusToGeneModel(model=RandomForestClassifier()) + + +@pytest.mark.parametrize( + "feature_class", + [ + EQtlColocH4MaximumFeature, + PQtlColocH4MaximumFeature, + SQtlColocH4MaximumFeature, + TuQtlColocH4MaximumFeature, + EQtlColocClppMaximumFeature, + PQtlColocClppMaximumFeature, + SQtlColocClppMaximumFeature, + TuQtlColocClppMaximumFeature, + ], +) +def test_colocalisation_feature_type( + feature_class: Any, + mock_study_locus: StudyLocus, + mock_colocalisation: Colocalisation, +) -> None: + """Test that every colocalisation feature type returns a set of L2GFeatures.""" + feature_dataset = feature_class.compute( + study_loci_to_annotate=mock_study_locus, feature_dependency=mock_colocalisation + ) + assert isinstance(feature_dataset, L2GFeature) # class TestColocalisationFactory: From 7ab1ff153cc9f882221fcbfe446a5ecf04dc7f1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 13 Sep 2024 14:36:01 +0100 Subject: [PATCH 21/48] test(colocalisation): add 
test_extract_maximum_coloc_probability_per_region_and_gene --- src/gentropy/dataset/colocalisation.py | 61 +++++++++++++++++-- src/gentropy/method/colocalisation.py | 2 +- tests/gentropy/dataset/test_colocalisation.py | 35 +++++++++++ 3 files changed, 91 insertions(+), 7 deletions(-) diff --git a/src/gentropy/dataset/colocalisation.py b/src/gentropy/dataset/colocalisation.py index 226aaca99..3d0064ea6 100644 --- a/src/gentropy/dataset/colocalisation.py +++ b/src/gentropy/dataset/colocalisation.py @@ -17,6 +17,7 @@ from pyspark.sql.types import StructType from gentropy.dataset.l2g_gold_standard import L2GGoldStandard + from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus from functools import reduce @@ -38,13 +39,15 @@ def get_schema(cls: type[Colocalisation]) -> StructType: def extract_maximum_coloc_probability_per_region_and_gene( self: Colocalisation, study_loci: StudyLocus | L2GGoldStandard, + study_index: StudyIndex, filter_by_colocalisation_method: str, - filter_by_qtl: str | None, + filter_by_qtl: str | None = None, ) -> DataFrame: """Get maximum colocalisation probability for a (studyLocus, gene) window. Args: study_loci (StudyLocus | L2GGoldStandard): Dataset containing study loci to filter the colocalisation dataset on and the geneId linked to the region + study_index (StudyIndex): Study index to use to get study metadata filter_by_colocalisation_method (str): optional filter to apply on the colocalisation dataset filter_by_qtl (str | None): optional filter to apply on the colocalisation dataset @@ -74,17 +77,63 @@ def extract_maximum_coloc_probability_per_region_and_gene( ).METHOD_METRIC # type: ignore coloc_filtering_expr = [ - (f.col("rightStudyType") != "gwas"), + f.col("rightGeneId").isNull(), f.col("colocalisationMethod") == filter_by_colocalisation_method, ] if filter_by_qtl: coloc_filtering_expr.append(f.col("rightStudyType") == filter_by_qtl) + filtered_colocalisation = ( + # Bring rightStudyType and rightGeneId and filter by rows where the gene is null, + # which is equivalent to filtering studyloci from gwas on the right side + self.append_right_study_metadata( + study_loci, study_index, ["studyType", "geneId"] + ) + # it also filters based on method and qtl type + .filter(reduce(lambda a, b: a & b, coloc_filtering_expr)) + # and filters colocalisation results to only include the subset of studylocus that contains gwas studylocusid + .join( + study_loci.df.selectExpr("studyLocusId as leftStudyLocusId"), + "leftStudyLocusId", + ) + ) + return get_record_with_maximum_value( - # Filter coloc dataset based on method and qtl type - self.filter(reduce(lambda a, b: a & b, coloc_filtering_expr)) - # Join with study loci to get geneId - .df.join(study_loci.df.select("studyLocusId", "geneId"), "studyLocusId"), + filtered_colocalisation.withColumnRenamed( + "leftStudyLocusId", "studyLocusId" + ).withColumnRenamed("rightGeneId", "geneId"), ["studyLocusId", "geneId"], method_colocalisation_metric, ) + + def append_right_study_metadata( + self: Colocalisation, + study_loci: StudyLocus | L2GGoldStandard, + study_index: StudyIndex, + metadata_cols: list[str], + ) -> DataFrame: + """Appends metadata from the study in the right side of the colocalisation dataset. 
+ + Args: + study_loci (StudyLocus | L2GGoldStandard): Dataset containing study loci that links the colocalisation dataset and the study index via the studyId + study_index (StudyIndex): Dataset containing study index that contains the metadata + metadata_cols (list[str]): List of study columns to append + + Returns: + DataFrame: Colocalisation dataset with appended metadata of the right study + """ + # TODO: make this flexible to bring metadata from the left study (2 joins) + return self.df.join( + study_loci.df.selectExpr( + "studyLocusId as rightStudyLocusId", "studyId as rightStudyId" + ), + "rightStudyLocusId", + "left", + ).join( + study_index.df.selectExpr( + "studyId as rightStudyId", + *[f"{col} as right{col[0].upper() + col[1:]}" for col in metadata_cols], + ), + "rightStudyId", + "left", + ) diff --git a/src/gentropy/method/colocalisation.py b/src/gentropy/method/colocalisation.py index 18d97fdf8..c3320f931 100644 --- a/src/gentropy/method/colocalisation.py +++ b/src/gentropy/method/colocalisation.py @@ -112,7 +112,7 @@ class Coloc: """ METHOD_NAME: str = "COLOC" - METHOD_METRIC: str = "llr" + METHOD_METRIC: str = "h4" PSEUDOCOUNT: float = 1e-10 @staticmethod diff --git a/tests/gentropy/dataset/test_colocalisation.py b/tests/gentropy/dataset/test_colocalisation.py index 1651aa2d4..c790ffd8b 100644 --- a/tests/gentropy/dataset/test_colocalisation.py +++ b/tests/gentropy/dataset/test_colocalisation.py @@ -3,8 +3,43 @@ from __future__ import annotations from gentropy.dataset.colocalisation import Colocalisation +from gentropy.dataset.study_index import StudyIndex +from gentropy.dataset.study_locus import StudyLocus def test_colocalisation_creation(mock_colocalisation: Colocalisation) -> None: """Test colocalisation creation with mock data.""" assert isinstance(mock_colocalisation, Colocalisation) + + +def test_append_right_study_metadata( + mock_colocalisation: Colocalisation, + mock_study_locus: StudyLocus, + mock_study_index: StudyIndex, + metadata_cols: list[str] | None = None, +) -> None: + """Test appending right study metadata.""" + if metadata_cols is None: + metadata_cols = ["studyType"] + expected_extra_col = ["rightStudyType"] + res_df = mock_colocalisation.append_right_study_metadata( + mock_study_locus, mock_study_index, metadata_cols + ) + for col in expected_extra_col: + assert col in res_df.columns, f"Column {col} not found in result DataFrame." + + +def test_extract_maximum_coloc_probability_per_region_and_gene( + mock_colocalisation: Colocalisation, + mock_study_locus: StudyLocus, + mock_study_index: StudyIndex, + filter_by_colocalisation_method: str | None = None, +) -> None: + """Test extracting maximum coloc probability per region and gene returns a dataframe with the correct columns: studyLocusId, geneId, h4.""" + filter_by_colocalisation_method = filter_by_colocalisation_method or "Coloc" + res_df = mock_colocalisation.extract_maximum_coloc_probability_per_region_and_gene( + mock_study_locus, mock_study_index, filter_by_colocalisation_method + ) + expected_cols = ["studyLocusId", "geneId", "h4"] + for col in expected_cols: + assert col in res_df.columns, f"Column {col} not found in result DataFrame." 
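
For orientation, a short usage sketch of the API these tests cover, again assuming `coloc` (Colocalisation), `credible_set` (StudyLocus) and `study_index` (StudyIndex) are loaded elsewhere. After this patch Coloc's METHOD_METRIC is "h4", so that is the column being maximised.

    # Append rightStudyType/rightGeneId to every colocalisation row, resolved
    # through the right study locus and the study index.
    annotated = coloc.append_right_study_metadata(
        credible_set, study_index, ["studyType", "geneId"]
    )

    # Keep the best Coloc H4 per (studyLocusId, geneId); the QTL filter is optional
    # and must be one of the QTL types known to EqtlCatalogueStudyIndex.
    best_h4 = coloc.extract_maximum_coloc_probability_per_region_and_gene(
        credible_set,
        study_index,
        filter_by_colocalisation_method="Coloc",
        filter_by_qtl="eqtl",
    )
    best_h4.select("studyLocusId", "geneId", "h4").show()
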
From e56e8ead31773aba0265d9b59d8c2ad413549c3b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Fri, 13 Sep 2024 16:14:27 +0100
Subject: [PATCH 22/48] feat(L2GFeatureInputLoader): support multiple deps by
 passing loader as kwarg

---
 src/gentropy/method/l2g/feature_factory.py   | 120 +++++++++++-------
 .../method/test_l2g/test_feature_factory.py  |  11 +-
 2 files changed, 81 insertions(+), 50 deletions(-)

diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py
index 67b8dd71d..fb4a5ee5d 100644
--- a/src/gentropy/method/l2g/feature_factory.py
+++ b/src/gentropy/method/l2g/feature_factory.py
@@ -10,6 +10,7 @@
 from gentropy.dataset.colocalisation import Colocalisation
 from gentropy.dataset.l2g_feature import L2GFeature
 from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
+from gentropy.dataset.study_index import StudyIndex
 from gentropy.dataset.study_locus import StudyLocus
 from gentropy.dataset.v2g import V2G
 
@@ -24,33 +25,39 @@ def __init__(
         self,
         **kwargs: Any,
     ) -> None:
-        """Initializes L2GFeatureInputLoader with the provided inputs and returns loaded dependencies as a list.
+        """Initializes L2GFeatureInputLoader with the provided inputs and returns loaded dependencies as a dictionary.
 
         Args:
             **kwargs (Any): keyword arguments with the name of the dependency and the dependency itself.
         """
-        self.input_dependencies = [v for v in kwargs.values() if v is not None]
+        self.input_dependencies = {k: v for k, v in kwargs.items() if v is not None}
 
-    def get_dependency(self, dependency_type: Any) -> Any:
+    def get_dependency_by_type(
+        self, dependency_type: list[Any] | Any
+    ) -> dict[str, Any]:
         """Returns the dependency that matches the provided type.
 
         Args:
-            dependency_type (Any): type of the dependency to return.
+            dependency_type (list[Any] | Any): type(s) of the dependency to return.
 
         Returns:
-            Any: dependency that matches the provided type.
+            dict[str, Any]: dictionary of dependencies that match the provided type(s).
         """
-        for dependency in self.input_dependencies:
-            if isinstance(dependency, dependency_type):
-                return dependency
+        if not isinstance(dependency_type, list):
+            dependency_type = [dependency_type]
+        return {
+            k: v
+            for k, v in self.input_dependencies.items()
+            if isinstance(v, tuple(dependency_type))
+        }
 
-    def __iter__(self) -> Iterator[dict[str, Any]]:
-        """Make the class iterable, returning the input dependencies list.
+    def __iter__(self) -> Iterator[tuple[str, Any]]:
+        """Make the class iterable, returning an iterator over key-value pairs.
 
         Returns:
-            Iterator[dict[str, Any]]: list of input dependencies.
+            Iterator[tuple[str, Any]]: iterator over the dictionary's key-value pairs.
         """
-        return iter(self.input_dependencies)
+        return iter(self.input_dependencies.items())
 
     def __repr__(self) -> str:
         """Return a string representation of the input dependencies.
@@ -64,33 +71,39 @@ def __repr__(self) -> str:
         return repr(self.input_dependencies)
 
 
 def _common_colocalisation_feature_logic(
-    feature_dependency: Colocalisation,
     study_loci_to_annotate: StudyLocus | L2GGoldStandard,
     colocalisation_method: str,
     colocalisation_metric: str,
+    feature_name: str,
     qtl_type: str,
+    *,
+    colocalisation: Colocalisation,
+    study_index: StudyIndex,
 ) -> DataFrame:
     """Wrapper to call the logic that creates a type of colocalisation features.
Args: - feature_dependency (Colocalisation): Dataset with the colocalisation results study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation colocalisation_method (str): The colocalisation method to filter the data by colocalisation_metric (str): The colocalisation metric to use + feature_name (str): The name of the feature to create qtl_type (str): The type of QTL to filter the data by + colocalisation (Colocalisation): Dataset with the colocalisation results + study_index (StudyIndex): Study index to fetch study type and gene Returns: DataFrame: Feature annotation in long format with the columns: studyLocusId, geneId, featureName, featureValue """ return convert_from_wide_to_long( - feature_dependency.extract_maximum_coloc_probability_per_region_and_gene( + colocalisation.extract_maximum_coloc_probability_per_region_and_gene( study_loci_to_annotate, + study_index, filter_by_colocalisation_method=colocalisation_method, filter_by_qtl=qtl_type, ).selectExpr( "studyLocusId", "geneId", - f"{colocalisation_metric} as cls.featureName", + f"{colocalisation_metric} as {feature_name}", ), id_vars=("studyLocusId", "geneId"), var_name="featureName", @@ -101,20 +114,20 @@ def _common_colocalisation_feature_logic( class EQtlColocClppMaximumFeature(L2GFeature): """Max CLPP for each (study, locus, gene) aggregating over all eQTLs.""" - feature_dependency_type = Colocalisation + feature_dependency_type = [Colocalisation, StudyIndex] feature_name = "eQtlColocClppMaximum" @classmethod def compute( cls: type[EQtlColocClppMaximumFeature], study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: Colocalisation, + feature_dependency: dict[str, Any], ) -> EQtlColocClppMaximumFeature: """Computes the feature. Args: study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (Colocalisation): Dataset with the colocalisation results + feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments. Returns: EQtlColocClppMaximumFeature: Feature dataset @@ -122,13 +135,15 @@ def compute( colocalisation_method = "ECaviar" colocalisation_metric = "clpp" qtl_type = "eqtl" + return cls( _df=_common_colocalisation_feature_logic( - feature_dependency, study_loci_to_annotate, colocalisation_method, colocalisation_metric, + cls.feature_name, qtl_type, + **feature_dependency, ), _schema=cls.get_schema(), ) @@ -137,20 +152,20 @@ def compute( class PQtlColocClppMaximumFeature(L2GFeature): """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" - feature_dependency_type = Colocalisation + feature_dependency_type = [Colocalisation, StudyIndex] feature_name = "pQtlColocClppMaximum" @classmethod def compute( cls: type[PQtlColocClppMaximumFeature], study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: Colocalisation, + feature_dependency: dict[str, Any], ) -> PQtlColocClppMaximumFeature: """Computes the feature. 
Args:
            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
-            feature_dependency (Colocalisation): Dataset with the colocalisation results
+            feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments.
 
         Returns:
             PQtlColocClppMaximumFeature: Feature dataset
@@ -160,11 +175,12 @@ def compute(
         qtl_type = "pqtl"
         return cls(
             _df=_common_colocalisation_feature_logic(
-                feature_dependency,
                 study_loci_to_annotate,
                 colocalisation_method,
                 colocalisation_metric,
+                cls.feature_name,
                 qtl_type,
+                **feature_dependency,
             ),
             _schema=cls.get_schema(),
         )
@@ -173,20 +189,20 @@ def compute(
 class SQtlColocClppMaximumFeature(L2GFeature):
     """Max CLPP for each (study, locus, gene) aggregating over all sQTLs."""
 
-    feature_dependency_type = Colocalisation
+    feature_dependency_type = [Colocalisation, StudyIndex]
     feature_name = "sQtlColocClppMaximum"
 
     @classmethod
     def compute(
         cls: type[SQtlColocClppMaximumFeature],
         study_loci_to_annotate: StudyLocus | L2GGoldStandard,
-        feature_dependency: Colocalisation,
+        feature_dependency: dict[str, Any],
     ) -> SQtlColocClppMaximumFeature:
         """Computes the feature.
 
         Args:
             study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
-            feature_dependency (Colocalisation): Dataset with the colocalisation results
+            feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments.
 
         Returns:
             SQtlColocClppMaximumFeature: Feature dataset
@@ -196,11 +212,12 @@ def compute(
         qtl_type = "sqtl"
         return cls(
             _df=_common_colocalisation_feature_logic(
-                feature_dependency,
                 study_loci_to_annotate,
                 colocalisation_method,
                 colocalisation_metric,
+                cls.feature_name,
                 qtl_type,
+                **feature_dependency,
             ),
             _schema=cls.get_schema(),
         )
@@ -209,20 +226,20 @@ def compute(
 class TuQtlColocClppMaximumFeature(L2GFeature):
     """Max CLPP for each (study, locus, gene) aggregating over all tuQTLs."""
 
-    feature_dependency_type = Colocalisation
+    feature_dependency_type = [Colocalisation, StudyIndex]
     feature_name = "tuQtlColocClppMaximum"
 
     @classmethod
     def compute(
         cls: type[TuQtlColocClppMaximumFeature],
         study_loci_to_annotate: StudyLocus | L2GGoldStandard,
-        feature_dependency: Colocalisation,
+        feature_dependency: dict[str, Any],
     ) -> TuQtlColocClppMaximumFeature:
         """Computes the feature.
 
         Args:
             study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
-            feature_dependency (Colocalisation): Dataset with the colocalisation results
+            feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments.
 
         Returns:
             TuQtlColocClppMaximumFeature: Feature dataset
@@ -232,11 +249,12 @@ def compute(
         qtl_type = "tuqtl"
         return cls(
             _df=_common_colocalisation_feature_logic(
-                feature_dependency,
                 study_loci_to_annotate,
                 colocalisation_method,
                 colocalisation_metric,
+                cls.feature_name,
                 qtl_type,
+                **feature_dependency,
             ),
             _schema=cls.get_schema(),
         )
@@ -245,20 +263,20 @@ def compute(
 class EQtlColocH4MaximumFeature(L2GFeature):
     """Max H4 for each (study, locus, gene) aggregating over all eQTLs."""
 
-    feature_dependency_type = Colocalisation
+    feature_dependency_type = [Colocalisation, StudyIndex]
     feature_name = "eQtlColocH4Maximum"
 
     @classmethod
     def compute(
         cls: type[EQtlColocH4MaximumFeature],
         study_loci_to_annotate: StudyLocus | L2GGoldStandard,
-        feature_dependency: Colocalisation,
+        feature_dependency: dict[str, Any],
     ) -> EQtlColocH4MaximumFeature:
         """Computes the feature.
 
         Args:
             study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
-            feature_dependency (Colocalisation): Dataset with the colocalisation results
+            feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments.
 
         Returns:
             EQtlColocH4MaximumFeature: Feature dataset
@@ -268,11 +286,12 @@ def compute(
         qtl_type = "eqtl"
         return cls(
             _df=_common_colocalisation_feature_logic(
-                feature_dependency,
                 study_loci_to_annotate,
                 colocalisation_method,
                 colocalisation_metric,
+                cls.feature_name,
                 qtl_type,
+                **feature_dependency,
             ),
             _schema=cls.get_schema(),
         )
@@ -281,20 +300,20 @@ def compute(
 class PQtlColocH4MaximumFeature(L2GFeature):
     """Max H4 for each (study, locus, gene) aggregating over all pQTLs."""
 
-    feature_dependency_type = Colocalisation
+    feature_dependency_type = [Colocalisation, StudyIndex]
     feature_name = "pQtlColocH4Maximum"
 
     @classmethod
     def compute(
         cls: type[PQtlColocH4MaximumFeature],
         study_loci_to_annotate: StudyLocus | L2GGoldStandard,
-        feature_dependency: Colocalisation,
+        feature_dependency: dict[str, Any],
     ) -> PQtlColocH4MaximumFeature:
         """Computes the feature.
 
         Args:
             study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
-            feature_dependency (Colocalisation): Dataset with the colocalisation results
+            feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments.
 
         Returns:
             PQtlColocH4MaximumFeature: Feature dataset
@@ -304,11 +323,12 @@ def compute(
         qtl_type = "pqtl"
         return cls(
             _df=_common_colocalisation_feature_logic(
-                feature_dependency,
                 study_loci_to_annotate,
                 colocalisation_method,
                 colocalisation_metric,
+                cls.feature_name,
                 qtl_type,
+                **feature_dependency,
             ),
             _schema=cls.get_schema(),
         )
@@ -317,20 +337,20 @@ def compute(
 class SQtlColocH4MaximumFeature(L2GFeature):
     """Max H4 for each (study, locus, gene) aggregating over all sQTLs."""
 
-    feature_dependency_type = Colocalisation
+    feature_dependency_type = [Colocalisation, StudyIndex]
     feature_name = "sQtlColocH4Maximum"
 
     @classmethod
     def compute(
         cls: type[SQtlColocH4MaximumFeature],
         study_loci_to_annotate: StudyLocus | L2GGoldStandard,
-        feature_dependency: Colocalisation,
+        feature_dependency: dict[str, Any],
     ) -> SQtlColocH4MaximumFeature:
         """Computes the feature.
 
         Args:
             study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
-            feature_dependency (Colocalisation): Dataset with the colocalisation results
+            feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments.
 
         Returns:
             SQtlColocH4MaximumFeature: Feature dataset
@@ -340,11 +360,12 @@ def compute(
         qtl_type = "sqtl"
         return cls(
             _df=_common_colocalisation_feature_logic(
-                feature_dependency,
                 study_loci_to_annotate,
                 colocalisation_method,
                 colocalisation_metric,
+                cls.feature_name,
                 qtl_type,
+                **feature_dependency,
             ),
             _schema=cls.get_schema(),
         )
@@ -353,20 +374,20 @@ def compute(
 class TuQtlColocH4MaximumFeature(L2GFeature):
     """Max H4 for each (study, locus, gene) aggregating over all tuQTLs."""
 
-    feature_dependency_type = Colocalisation
+    feature_dependency_type = [Colocalisation, StudyIndex]
     feature_name = "tuQtlColocH4Maximum"
 
     @classmethod
     def compute(
         cls: type[TuQtlColocH4MaximumFeature],
         study_loci_to_annotate: StudyLocus | L2GGoldStandard,
-        feature_dependency: Colocalisation,
+        feature_dependency: dict[str, Any],
     ) -> TuQtlColocH4MaximumFeature:
         """Computes the feature.
 
         Args:
             study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
-            feature_dependency (Colocalisation): Dataset with the colocalisation results
+            feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments.
 
         Returns:
             TuQtlColocH4MaximumFeature: Feature dataset
@@ -376,11 +397,12 @@ def compute(
         qtl_type = "tuqtl"
         return cls(
             _df=_common_colocalisation_feature_logic(
-                feature_dependency,
                 study_loci_to_annotate,
                 colocalisation_method,
                 colocalisation_metric,
+                cls.feature_name,
                 qtl_type,
+                **feature_dependency,
             ),
             _schema=cls.get_schema(),
         )
@@ -545,7 +567,7 @@ def compute_feature(
             feature_dependency_type = feature_cls.feature_dependency_type
             return feature_cls.compute(
                 study_loci_to_annotate=self.study_loci_to_annotate,
-                feature_dependency=features_input_loader.get_dependency(
+                feature_dependency=features_input_loader.get_dependency_by_type(
                     feature_dependency_type
                 ),
             )
diff --git a/tests/gentropy/method/test_l2g/test_feature_factory.py b/tests/gentropy/method/test_l2g/test_feature_factory.py
index 8aff3dc6c..b2939a902 100644
--- a/tests/gentropy/method/test_l2g/test_feature_factory.py
+++ b/tests/gentropy/method/test_l2g/test_feature_factory.py
@@ -10,6 +10,7 @@
 from gentropy.method.l2g.feature_factory import (
     EQtlColocClppMaximumFeature,
     EQtlColocH4MaximumFeature,
+    L2GFeatureInputLoader,
     PQtlColocClppMaximumFeature,
     PQtlColocH4MaximumFeature,
     SQtlColocClppMaximumFeature,
@@ -20,6 +21,7 @@
 
 if TYPE_CHECKING:
     from gentropy.dataset.colocalisation import Colocalisation
+    from gentropy.dataset.study_index import StudyIndex
     from gentropy.dataset.study_locus import StudyLocus
 
 
@@ -46,10 +48,17 @@ def test_colocalisation_feature_type(
     feature_class: Any,
     mock_study_locus: StudyLocus,
     mock_colocalisation: Colocalisation,
+    mock_study_index: StudyIndex,
 ) -> None:
     """Test that every colocalisation feature type returns a set of L2GFeatures."""
+    loader = L2GFeatureInputLoader(
+        colocalisation=mock_colocalisation, study_index=mock_study_index
+    )
     feature_dataset = feature_class.compute(
-        study_loci_to_annotate=mock_study_locus, feature_dependency=mock_colocalisation
+        study_loci_to_annotate=mock_study_locus,
+        feature_dependency=loader.get_dependency_by_type(
+            feature_class.feature_dependency_type
+        ),
     )
     assert isinstance(feature_dataset, L2GFeature)

From b8525add353d6bcfbb618fc21632ca0197551e5e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Fri, 13 Sep 2024 16:37:47 +0100
Subject: [PATCH 23/48] test: add integration tests `test_build_feature_matrix`

---
 src/gentropy/dataset/l2g_feature_matrix.py |  2 +-
 src/gentropy/method/l2g/feature_factory.py | 16 ++++++------
 tests/gentropy/dataset/test_study_locus.py | 19 +++++++++++++++
 .../open_targets/test_l2g_gold_standard.py  | 23 +++++++++++++++++++
 4 files changed, 51 insertions(+), 9 deletions(-)

diff --git a/src/gentropy/dataset/l2g_feature_matrix.py b/src/gentropy/dataset/l2g_feature_matrix.py
index b64961ca5..6e4bb3c9e 100644
--- a/src/gentropy/dataset/l2g_feature_matrix.py
+++ b/src/gentropy/dataset/l2g_feature_matrix.py
@@ -6,12 +6,12 @@
 from typing import TYPE_CHECKING, Type
 
 from gentropy.common.spark_helpers import convert_from_long_to_wide
+from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
 from gentropy.method.l2g.feature_factory import FeatureFactory, L2GFeatureInputLoader
 
 if TYPE_CHECKING:
     from pyspark.sql import DataFrame
 
-    from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
     from gentropy.dataset.study_locus import
StudyLocus diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index fb4a5ee5d..9495ff905 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -499,14 +499,14 @@ class FeatureFactory: feature_mapper: Mapping[str, type[L2GFeature]] = { # "distanceTssMinimum": DistanceTssMinimumFeature, # "distanceTssMean": DistanceTssMeanFeature, - "eqtlColocClppMaximum": EQtlColocClppMaximumFeature, - "pqtlColocClppMaximum": PQtlColocClppMaximumFeature, - "sqtlColocClppMaximum": SQtlColocClppMaximumFeature, - "tuqtlColocClppMaximum": TuQtlColocClppMaximumFeature, - "eqtlColocH4Maximum": EQtlColocH4MaximumFeature, - "pqtlColocH4Maximum": PQtlColocH4MaximumFeature, - "sqtlColocH4Maximum": SQtlColocH4MaximumFeature, - "tuqtlColocH4Maximum": TuQtlColocH4MaximumFeature, + "eQtlColocClppMaximum": EQtlColocClppMaximumFeature, + "pQtlColocClppMaximum": PQtlColocClppMaximumFeature, + "sQtlColocClppMaximum": SQtlColocClppMaximumFeature, + "tuQtlColocClppMaximum": TuQtlColocClppMaximumFeature, + "eQtlColocH4Maximum": EQtlColocH4MaximumFeature, + "pQtlColocH4Maximum": PQtlColocH4MaximumFeature, + "sQtlColocH4Maximum": SQtlColocH4MaximumFeature, + "tuQtlColocH4Maximum": TuQtlColocH4MaximumFeature, } def __init__( diff --git a/tests/gentropy/dataset/test_study_locus.py b/tests/gentropy/dataset/test_study_locus.py index 9b40796db..994ed0287 100644 --- a/tests/gentropy/dataset/test_study_locus.py +++ b/tests/gentropy/dataset/test_study_locus.py @@ -18,6 +18,8 @@ StructType, ) +from gentropy.dataset.colocalisation import Colocalisation +from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.ld_index import LDIndex from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import ( @@ -27,6 +29,7 @@ ) from gentropy.dataset.study_locus_overlap import StudyLocusOverlap from gentropy.dataset.summary_statistics import SummaryStatistics +from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader @pytest.mark.parametrize( @@ -684,3 +687,19 @@ def test_study_validation_correctness(self: TestStudyLocusValidation) -> None: ) .count() ) == 1 + + +def test_build_feature_matrix( + mock_study_locus: StudyLocus, + mock_colocalisation: Colocalisation, + mock_study_index: StudyIndex, +) -> None: + """Test building feature matrix with the eQtlColocH4Maximum feature.""" + features_list = ["eQtlColocH4Maximum"] + loader = L2GFeatureInputLoader( + colocalisation=mock_colocalisation, study_index=mock_study_index + ) + fm = mock_study_locus.build_feature_matrix(features_list, loader) + assert isinstance( + fm, L2GFeatureMatrix + ), "Feature matrix should be of type L2GFeatureMatrix" diff --git a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py index 6f91d32a9..347b7ec69 100644 --- a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py +++ b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py @@ -7,15 +7,21 @@ import pytest from pyspark.sql import DataFrame +from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.l2g_gold_standard import L2GGoldStandard +from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.v2g import V2G from gentropy.datasource.open_targets.l2g_gold_standard import ( OpenTargetsL2GGoldStandard, ) +from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader if TYPE_CHECKING: from 
pyspark.sql.session import SparkSession
 
+    from gentropy.dataset.colocalisation import Colocalisation
+    from gentropy.dataset.study_locus import StudyLocus
+
 
 def test_open_targets_as_l2g_gold_standard(
     sample_l2g_gold_standard: DataFrame,
@@ -104,3 +110,20 @@ def _setup(self: TestExpandGoldStandardWithNegatives, spark: SparkSession) -> No
                 V2G(_df=sample_v2g_df, _schema=V2G.get_schema()),
             )
         )
+
+
+def test_build_feature_matrix(
+    mock_l2g_gold_standard: L2GGoldStandard,
+    mock_study_locus: StudyLocus,
+    mock_colocalisation: Colocalisation,
+    mock_study_index: StudyIndex,
+) -> None:
+    """Test building feature matrix with the eQtlColocH4Maximum feature."""
+    features_list = ["eQtlColocH4Maximum"]
+    loader = L2GFeatureInputLoader(
+        colocalisation=mock_colocalisation, study_index=mock_study_index
+    )
+    fm = mock_study_locus.build_feature_matrix(features_list, loader)
+    assert isinstance(
+        mock_l2g_gold_standard.build_feature_matrix(fm), L2GFeatureMatrix
+    ), "Feature matrix should be of type L2GFeatureMatrix"

From e032baeb7f4c17eb1388214191385e741684d24f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Tue, 17 Sep 2024 10:53:42 +0100
Subject: [PATCH 24/48] feat(variant_index): add `get_distance_to_gene` and
 deprecate `get_distance_to_tss`

---
 src/gentropy/common/spark_helpers.py          |  7 ++-
 src/gentropy/dataset/variant_index.py         | 61 +++++++------------
 src/gentropy/datasource/ensembl/vep_parser.py |  6 +-
 tests/gentropy/dataset/test_variant_index.py  | 19 +++---
 4 files changed, 41 insertions(+), 52 deletions(-)

diff --git a/src/gentropy/common/spark_helpers.py b/src/gentropy/common/spark_helpers.py
index 65d3ae17b..6d078e5aa 100644
--- a/src/gentropy/common/spark_helpers.py
+++ b/src/gentropy/common/spark_helpers.py
@@ -6,7 +6,7 @@
 import sys
 from functools import reduce, wraps
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar
 
 import pyspark.sql.functions as f
 import pyspark.sql.types as t
@@ -429,14 +429,15 @@ def order_array_of_structs_by_two_fields(
     """
     )
 
-def map_column_by_dictionary(col: Column, mapping_dict: Dict[str, str]) -> Column:
+
+def map_column_by_dictionary(col: Column, mapping_dict: dict[str, str]) -> Column:
     """Map column values to dictionary values by key.
 
     Missing consequence label will be converted to None, unmapped consequences will be mapped as None.
 
     Args:
         col (Column): Column containing labels to map.
-        mapping_dict (Dict[str, str]): Dictionary with mapping key/value pairs.
+        mapping_dict (dict[str, str]): Dictionary with mapping key/value pairs.
 
     Returns:
         Column: Column with mapped values.
diff --git a/src/gentropy/dataset/variant_index.py b/src/gentropy/dataset/variant_index.py
index 1cc1eac1b..820e42f01 100644
--- a/src/gentropy/dataset/variant_index.py
+++ b/src/gentropy/dataset/variant_index.py
@@ -10,7 +10,6 @@
 from gentropy.common.schemas import parse_spark_schema
 from gentropy.common.spark_helpers import (
     get_record_with_maximum_value,
-    normalise_column,
     rename_all_columns,
     safe_array_union,
 )
@@ -251,52 +250,36 @@ def get_transcript_consequence_df(
         )
         return transript_consequences
 
-    def get_distance_to_tss(
+    def get_distance_to_gene(
         self: VariantIndex,
-        gene_index: GeneIndex,
+        *,
+        distance_type: str = "distanceFromTss",
         max_distance: int = 500_000,
-    ) -> V2G:
-        """Extracts variant to gene assignments for variants falling within a window of a gene's TSS.
+ ) -> DataFrame: + """Extracts variant to gene assignments for variants falling within a window of a gene's TSS or footprint. Args: - gene_index (GeneIndex): A gene index to filter by. - max_distance (int): The maximum distance from the TSS to consider. Defaults to 500_000. + distance_type (str): The type of distance to use. Can be "distanceFromTss" or "distanceFromFootprint". Defaults to "distanceFromTss". + max_distance (int): The maximum distance to consider. Defaults to 500_000. Returns: - V2G: variant to gene assignments with their distance to the TSS + DataFrame: A dataframe with the distance between a variant and a gene's TSS or footprint. + + Raises: + ValueError: Invalid distance type. """ - return V2G( - _df=( - self.df.alias("variant") - .join( - f.broadcast(gene_index.locations_lut()).alias("gene"), - on=[ - f.col("variant.chromosome") == f.col("gene.chromosome"), - f.abs(f.col("variant.position") - f.col("gene.tss")) - <= max_distance, - ], - how="inner", + if distance_type in {"distanceFromTss", "distanceFromFootprint"}: + return ( + self.df.select( + "variantId", f.explode("transcriptConsequences").alias("tc") ) - .withColumn( - "distance", f.abs(f.col("variant.position") - f.col("gene.tss")) - ) - .withColumn( - "inverse_distance", - max_distance - f.col("distance"), - ) - .transform(lambda df: normalise_column(df, "inverse_distance", "score")) - .select( - "variantId", - f.col("variant.chromosome").alias("chromosome"), - "distance", - "geneId", - "score", - f.lit("distance").alias("datatypeId"), - f.lit("canonical_tss").alias("datasourceId"), - ) - ), - _schema=V2G.get_schema(), - ) + .select("variantId", "tc.targetId", f"tc.{distance_type}") + .filter(f.col(distance_type) <= max_distance) + ) + else: + raise ValueError( + f"Invalid distance_type: {distance_type}. Must be 'distanceFromTss' or 'distanceFromFootprint'." + ) def get_plof_v2g(self: VariantIndex, gene_index: GeneIndex) -> V2G: """Creates a dataset with variant to gene assignments with a flag indicating if the variant is predicted to be a loss-of-function variant by the LOFTEE algorithm. diff --git a/src/gentropy/datasource/ensembl/vep_parser.py b/src/gentropy/datasource/ensembl/vep_parser.py index c7ee05d13..c546d2c27 100644 --- a/src/gentropy/datasource/ensembl/vep_parser.py +++ b/src/gentropy/datasource/ensembl/vep_parser.py @@ -3,7 +3,7 @@ from __future__ import annotations import importlib.resources as pkg_resources -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING import pandas as pd from pyspark.sql import SparkSession @@ -527,14 +527,14 @@ def _collect_uniprot_accessions(trembl: Column, swissprot: Column) -> Column: ) @staticmethod - def _parse_variant_location_id(vep_input_field: Column) -> List[Column]: + def _parse_variant_location_id(vep_input_field: Column) -> list[Column]: r"""Parse variant identifier, chromosome, position, reference allele and alternate allele from VEP input field. Args: vep_input_field (Column): Column containing variant vcf string used as VEP input. Returns: - List[Column]: List of columns containing chromosome, position, reference allele and alternate allele. + list[Column]: List of columns containing chromosome, position, reference allele and alternate allele. 
""" variant_fields = f.split(vep_input_field, r"\t") return [ diff --git a/tests/gentropy/dataset/test_variant_index.py b/tests/gentropy/dataset/test_variant_index.py index 12afba89f..16c68983b 100644 --- a/tests/gentropy/dataset/test_variant_index.py +++ b/tests/gentropy/dataset/test_variant_index.py @@ -28,13 +28,6 @@ def test_get_plof_v2g( assert isinstance(mock_variant_index.get_plof_v2g(mock_gene_index), V2G) -def test_get_distance_to_tss( - mock_variant_index: VariantIndex, mock_gene_index: GeneIndex -) -> None: - """Test get_distance_to_tss with mock variant annotation.""" - assert isinstance(mock_variant_index.get_distance_to_tss(mock_gene_index), V2G) - - class TestVariantIndex: """Collection of tests around the functionality and shape of the variant index.""" @@ -147,3 +140,15 @@ def test_rsid_column_updated(self: TestVariantIndex) -> None: .count() == 2 ) + + @pytest.mark.parametrize( + "distance_type", ["distanceFromTss", "distanceFromFootprint"] + ) + def test_get_distance_to_gene( + self: TestVariantIndex, mock_variant_index: VariantIndex, distance_type: str + ) -> None: + """Assert that the function returns a df with the requested column.""" + expected_cols = ["variantId", "targetId", distance_type] + observed = mock_variant_index.get_distance_to_gene(distance_type=distance_type) + for col in expected_cols: + assert col in observed.columns, f"Column {col} not in {observed.columns}" From 6370f672f5caeb9e0fa75d869c594d928dba6604 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 17 Sep 2024 12:07:01 +0100 Subject: [PATCH 25/48] feat(variant_index): deprecate `get_most_severe_transcript_consequence` --- src/gentropy/dataset/variant_index.py | 105 +++++++------------ tests/gentropy/conftest.py | 13 +++ tests/gentropy/dataset/test_variant_index.py | 22 +++- 3 files changed, 73 insertions(+), 67 deletions(-) diff --git a/src/gentropy/dataset/variant_index.py b/src/gentropy/dataset/variant_index.py index 820e42f01..cbb8964c2 100644 --- a/src/gentropy/dataset/variant_index.py +++ b/src/gentropy/dataset/variant_index.py @@ -220,36 +220,6 @@ def filter_by_variant(self: VariantIndex, df: DataFrame) -> VariantIndex: _schema=self.schema, ) - def get_transcript_consequence_df( - self: VariantIndex, gene_index: GeneIndex | None = None - ) -> DataFrame: - """Dataframe of exploded transcript consequences. - - Optionally the trancript consequences can be reduced to the universe of a gene index. - - Args: - gene_index (GeneIndex | None): A gene index. Defaults to None. - - Returns: - DataFrame: A dataframe exploded by transcript consequences with the columns variantId, chromosome, transcriptConsequence - """ - # exploding the array removes records without VEP annotation - transript_consequences = self.df.withColumn( - "transcriptConsequence", f.explode("transcriptConsequences") - ).select( - "variantId", - "chromosome", - "position", - "transcriptConsequence", - f.col("transcriptConsequence.targetId").alias("geneId"), - ) - if gene_index: - transript_consequences = transript_consequences.join( - f.broadcast(gene_index.df), - on=["chromosome", "geneId"], - ) - return transript_consequences - def get_distance_to_gene( self: VariantIndex, *, @@ -260,7 +230,7 @@ def get_distance_to_gene( Args: distance_type (str): The type of distance to use. Can be "distanceFromTss" or "distanceFromFootprint". Defaults to "distanceFromTss". - max_distance (int): The maximum distance to consider. Defaults to 500_000. + max_distance (int): The maximum distance to consider. 
Defaults to 500_000, the default window size for VEP.
 
         Returns:
             DataFrame: A dataframe with the distance between a variant and a gene's TSS or footprint.
 
         Raises:
             ValueError: Invalid distance type.
         """
-        if distance_type in {"distanceFromTss", "distanceFromFootprint"}:
-            return (
-                self.df.select(
-                    "variantId", f.explode("transcriptConsequences").alias("tc")
-                )
-                .select("variantId", "tc.targetId", f"tc.{distance_type}")
-                .filter(f.col(distance_type) <= max_distance)
+        if distance_type not in {"distanceFromTss", "distanceFromFootprint"}:
+            raise ValueError(
+                f"Invalid distance_type: {distance_type}. Must be 'distanceFromTss' or 'distanceFromFootprint'."
             )
+        df = self.df.select(
+            "variantId", f.explode("transcriptConsequences").alias("tc")
+        ).select("variantId", "tc.targetId", f"tc.{distance_type}")
+        if max_distance == 500_000:
+            return df
+        elif max_distance < 500_000:
+            return df.filter(f"{distance_type} <= {max_distance}")
         else:
             raise ValueError(
-                f"Invalid distance_type: {distance_type}. Must be 'distanceFromTss' or 'distanceFromFootprint'."
+                f"max_distance must be 500_000 or less. Got {max_distance}."
             )
 
     def get_plof_v2g(self: VariantIndex, gene_index: GeneIndex) -> V2G:
         """Creates a dataset with variant to gene assignments with a flag indicating if the variant is predicted to be a loss-of-function variant by the LOFTEE algorithm.
@@ -294,8 +267,9 @@ def get_plof_v2g(self: VariantIndex, gene_index: GeneIndex) -> V2G:
         """
         return V2G(
             _df=(
-                self.get_transcript_consequence_df(gene_index)
-                .filter(f.col("transcriptConsequence.lofteePrediction").isNotNull())
+                self.df.filter(
+                    f.col("transcriptConsequence.lofteePrediction").isNotNull()
+                )
                 .withColumn(
                     "isHighQualityPlof",
                     f.when(
                         f.col("transcriptConsequence.lofteePrediction") == "HC", True
                     ).when(
                         f.col("transcriptConsequence.lofteePrediction") == "LC", False
                     ),
                 )
@@ -323,46 +297,47 @@ def get_plof_v2g(self: VariantIndex, gene_index: GeneIndex) -> V2G:
             _schema=V2G.get_schema(),
         )
 
-    def get_most_severe_transcript_consequence(
+    def get_most_severe_gene_consequence(
         self: VariantIndex,
+        *,
         vep_consequences: DataFrame,
-        gene_index: GeneIndex,
-    ) -> V2G:
-        """Creates a dataset with variant to gene assignments based on VEP's predicted consequence of the transcript.
-
-        Optionally the trancript consequences can be reduced to the universe of a gene index.
+    ) -> DataFrame:
+        """Returns a dataframe with the most severe consequence for a variant/gene pair.
 
         Args:
             vep_consequences (DataFrame): A dataframe of VEP consequences
-            gene_index (GeneIndex): A gene index to filter by. Defaults to None.
Returns: - V2G: High and medium severity variant to gene assignments + DataFrame: High and medium severity variant to gene assignments """ - return V2G( - _df=self.get_transcript_consequence_df(gene_index) + return ( + self.df.select("variantId", f.explode("transcriptConsequences").alias("tc")) .select( "variantId", - "chromosome", - f.col("transcriptConsequence.targetId").alias("geneId"), - f.explode( - "transcriptConsequence.variantFunctionalConsequenceIds" - ).alias("variantFunctionalConsequenceId"), - f.lit("vep").alias("datatypeId"), - f.lit("variantConsequence").alias("datasourceId"), + f.col("tc.targetId"), + f.explode(f.col("tc.variantFunctionalConsequenceIds")).alias( + "variantFunctionalConsequenceId" + ), ) .join( - f.broadcast(vep_consequences), + # TODO: make this table a project config + f.broadcast( + vep_consequences.selectExpr( + "variantFunctionalConsequenceId", "score as severityScore" + ) + ), on="variantFunctionalConsequenceId", how="inner", ) - .drop("label") - .filter(f.col("score") != 0) - # A variant can have multiple predicted consequences on a transcript, the most severe one is selected + .filter((f.col("severityScore") != 0) | (f.col("severityScore").isNull())) .transform( + # A variant can have multiple predicted consequences on a transcript, the most severe one is selected lambda df: get_record_with_maximum_value( - df, ["variantId", "geneId"], "score" + df, ["variantId", "targetId"], "severityScore" ) - ), - _schema=V2G.get_schema(), + ) + .withColumnRenamed( + "variantFunctionalConsequenceId", + "mostSevereVariantFunctionalConsequenceId", + ) ) diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py index 93ee38471..903a9d5b5 100644 --- a/tests/gentropy/conftest.py +++ b/tests/gentropy/conftest.py @@ -275,6 +275,19 @@ def mock_v2g(spark: SparkSession) -> V2G: return V2G(_df=data_spec.build(), _schema=v2g_schema) +@pytest.fixture() +def mock_variant_consequence_to_score(spark: SparkSession) -> DataFrame: + """Slice of the VEP consequence to score table.""" + return spark.createDataFrame( + [ + ("SO_0001893", "transcript_ablation", 1.0), + ("SO_0001822", "inframe_deletion", 0.66), + ("SO_0001567", "stop_retained_variant", 0.33), + ], + ["variantFunctionalConsequenceId", "label", "score"], + ) + + @pytest.fixture() def mock_variant_index(spark: SparkSession) -> VariantIndex: """Mock variant index.""" diff --git a/tests/gentropy/dataset/test_variant_index.py b/tests/gentropy/dataset/test_variant_index.py index 16c68983b..d57603d37 100644 --- a/tests/gentropy/dataset/test_variant_index.py +++ b/tests/gentropy/dataset/test_variant_index.py @@ -13,7 +13,7 @@ from gentropy.dataset.variant_index import VariantIndex if TYPE_CHECKING: - from pyspark.sql import SparkSession + from pyspark.sql import DataFrame, SparkSession def test_variant_index_creation(mock_variant_index: VariantIndex) -> None: @@ -147,8 +147,26 @@ def test_rsid_column_updated(self: TestVariantIndex) -> None: def test_get_distance_to_gene( self: TestVariantIndex, mock_variant_index: VariantIndex, distance_type: str ) -> None: - """Assert that the function returns a df with the requested column.""" + """Assert that the function returns a df with the requested columns.""" expected_cols = ["variantId", "targetId", distance_type] observed = mock_variant_index.get_distance_to_gene(distance_type=distance_type) for col in expected_cols: assert col in observed.columns, f"Column {col} not in {observed.columns}" + + def test_get_most_severe_gene_consequence( + self: TestVariantIndex, + 
mock_variant_index: VariantIndex,
+        mock_variant_consequence_to_score: DataFrame,
+    ) -> None:
+        """Assert that the function returns a df with the requested columns."""
+        expected_cols = [
+            "variantId",
+            "targetId",
+            "mostSevereVariantFunctionalConsequenceId",
+            "severityScore",
+        ]
+        observed = mock_variant_index.get_most_severe_gene_consequence(
+            vep_consequences=mock_variant_consequence_to_score
+        )
+        for col in expected_cols:
+            assert col in observed.columns, f"Column {col} not in {observed.columns}"

From fae8256aa3b33ac6b3c95c851c96f4b5590ef0e0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Tue, 17 Sep 2024 12:36:05 +0100
Subject: [PATCH 26/48] feat(variant_index): add `get_loftee` and deprecate
 `get_plof_v2g`

---
 src/gentropy/dataset/variant_index.py        | 62 +++++++-------------
 tests/gentropy/dataset/test_variant_index.py | 23 +++++---
 2 files changed, 35 insertions(+), 50 deletions(-)

diff --git a/src/gentropy/dataset/variant_index.py b/src/gentropy/dataset/variant_index.py
index cbb8964c2..014d36559 100644
--- a/src/gentropy/dataset/variant_index.py
+++ b/src/gentropy/dataset/variant_index.py
@@ -14,13 +14,11 @@
     safe_array_union,
 )
 from gentropy.dataset.dataset import Dataset
-from gentropy.dataset.v2g import V2G
 
 if TYPE_CHECKING:
     from pyspark.sql import Column, DataFrame
     from pyspark.sql.types import StructType
 
-    from gentropy.dataset.gene_index import GeneIndex
 
 @dataclass
@@ -254,47 +252,29 @@ def get_distance_to_gene(
                 f"max_distance must be 500_000 or less. Got {max_distance}."
             )
 
-    def get_plof_v2g(self: VariantIndex, gene_index: GeneIndex) -> V2G:
-        """Creates a dataset with variant to gene assignments with a flag indicating if the variant is predicted to be a loss-of-function variant by the LOFTEE algorithm.
+    def get_loftee(self: VariantIndex) -> DataFrame:
+        """Returns a dataframe with a flag indicating whether a variant is predicted to cause loss of function in a gene. The source of this information is the LOFTEE algorithm (https://github.com/konradjk/loftee).
 
-        Optionally the trancript consequences can be reduced to the universe of a gene index.
-
-        Args:
-            gene_index (GeneIndex): A gene index to filter by.
+        !!! note "This will return a filtered dataframe with only variants that have been annotated by LOFTEE."
Returns: - V2G: variant to gene assignments from the LOFTEE algorithm + DataFrame: variant to gene assignments from the LOFTEE algorithm """ - return V2G( - _df=( - self.df.filter( - f.col("transcriptConsequence.lofteePrediction").isNotNull() - ) - .withColumn( - "isHighQualityPlof", - f.when( - f.col("transcriptConsequence.lofteePrediction") == "HC", True - ).when( - f.col("transcriptConsequence.lofteePrediction") == "LC", False - ), - ) - .withColumn( - "score", - f.when(f.col("isHighQualityPlof"), 1.0).when( - ~f.col("isHighQualityPlof"), 0 - ), - ) - .select( - "variantId", - "chromosome", - "geneId", - "isHighQualityPlof", - f.col("score"), - f.lit("vep").alias("datatypeId"), - f.lit("loftee").alias("datasourceId"), - ) - ), - _schema=V2G.get_schema(), + return ( + self.df.select("variantId", f.explode("transcriptConsequences").alias("tc")) + .filter(f.col("tc.lofteePrediction").isNotNull()) + .withColumn( + "isHighQualityPlof", + f.when(f.col("tc.lofteePrediction") == "HC", True).when( + f.col("tc.lofteePrediction") == "LC", False + ), + ) + .select( + "variantId", + f.col("tc.targetId"), + f.col("tc.lofteePrediction"), + "isHighQualityPlof", + ) ) def get_most_severe_gene_consequence( @@ -308,7 +288,7 @@ def get_most_severe_gene_consequence( vep_consequences (DataFrame): A dataframe of VEP consequences Returns: - DataFrame: High and medium severity variant to gene assignments + DataFrame: A dataframe with the most severe consequence (plus a severity score) for a variant/gene pair """ return ( self.df.select("variantId", f.explode("transcriptConsequences").alias("tc")) @@ -329,7 +309,7 @@ def get_most_severe_gene_consequence( on="variantFunctionalConsequenceId", how="inner", ) - .filter((f.col("severityScore") != 0) | (f.col("severityScore").isNull())) + .filter(f.col("severityScore").isNull()) .transform( # A variant can have multiple predicted consequences on a transcript, the most severe one is selected lambda df: get_record_with_maximum_value( diff --git a/tests/gentropy/dataset/test_variant_index.py b/tests/gentropy/dataset/test_variant_index.py index d57603d37..29a6ef035 100644 --- a/tests/gentropy/dataset/test_variant_index.py +++ b/tests/gentropy/dataset/test_variant_index.py @@ -8,8 +8,6 @@ from pyspark.sql import functions as f from pyspark.sql import types as t -from gentropy.dataset.gene_index import GeneIndex -from gentropy.dataset.v2g import V2G from gentropy.dataset.variant_index import VariantIndex if TYPE_CHECKING: @@ -21,13 +19,6 @@ def test_variant_index_creation(mock_variant_index: VariantIndex) -> None: assert isinstance(mock_variant_index, VariantIndex) -def test_get_plof_v2g( - mock_variant_index: VariantIndex, mock_gene_index: GeneIndex -) -> None: - """Test get_plof_v2g with mock variant annotation.""" - assert isinstance(mock_variant_index.get_plof_v2g(mock_gene_index), V2G) - - class TestVariantIndex: """Collection of tests around the functionality and shape of the variant index.""" @@ -170,3 +161,17 @@ def test_get_most_severe_gene_consequence( ) for col in expected_cols: assert col in observed.columns, f"Column {col} not in {observed.columns}" + + def test_get_loftee( + self: TestVariantIndex, mock_variant_index: VariantIndex + ) -> None: + """Assert that the function returns a df with the requested columns.""" + expected_cols = [ + "variantId", + "targetId", + "lofteePrediction", + "isHighQualityPlof", + ] + observed = mock_variant_index.get_loftee() + for col in expected_cols: + assert col in observed.columns, f"Column {col} not in {observed.columns}" 
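[Patches 24-26 above replace the old V2G-producing helpers with plain DataFrame accessors on `VariantIndex`. As a minimal usage sketch of how the three new accessors could be combined — not part of the patch series itself; `variant_index` stands for a loaded `VariantIndex` instance and `vep_consequence_scores` for the consequence-to-score lookup table, both illustrative names loaded elsewhere:

    # Hypothetical usage of the accessors introduced in patches 24-26.
    distances = variant_index.get_distance_to_gene(distance_type="distanceFromTss")
    loftee_flags = variant_index.get_loftee()
    most_severe = variant_index.get_most_severe_gene_consequence(
        vep_consequences=vep_consequence_scores
    )
    # Each accessor returns a DataFrame keyed by (variantId, targetId), so the
    # annotations can be assembled into a single table with ordinary joins:
    annotations = distances.join(
        loftee_flags, ["variantId", "targetId"], "left"
    ).join(most_severe, ["variantId", "targetId"], "left")

Returning plain DataFrames rather than V2G datasets is what lets the later feature-factory patches consume these annotations directly.]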
From 4ca943cc4c9ec9e360129336ba24065e3b55aef1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3%B3pez?=
Date: Tue, 17 Sep 2024 13:45:14 +0100
Subject: [PATCH 27/48] chore: reduce v2g assessments to intervals

---
 src/gentropy/config.py          | 19 -----------------
 src/gentropy/dataset/v2g.py     | 10 +--------
 src/gentropy/variant_to_gene.py | 36 ++-------------------------------
 3 files changed, 3 insertions(+), 62 deletions(-)

diff --git a/src/gentropy/config.py b/src/gentropy/config.py
index aa8e331af..620e7a12b 100644
--- a/src/gentropy/config.py
+++ b/src/gentropy/config.py
@@ -333,27 +333,8 @@ class VariantToGeneConfig(StepConfig):
 
     variant_index_path: str = MISSING
     gene_index_path: str = MISSING
-    vep_consequences_path: str = MISSING
     liftover_chain_file_path: str = MISSING
     liftover_max_length_difference: int = 100
-    max_distance: int = 500_000
-    approved_biotypes: List[str] = field(
-        default_factory=lambda: [
-            "protein_coding",
-            "3prime_overlapping_ncRNA",
-            "antisense",
-            "bidirectional_promoter_lncRNA",
-            "IG_C_gene",
-            "IG_D_gene",
-            "IG_J_gene",
-            "IG_V_gene",
-            "lincRNA",
-            "macro_lncRNA",
-            "non_coding",
-            "sense_intronic",
-            "sense_overlapping",
-        ]
-    )
     interval_sources: Dict[str, str] = field(default_factory=dict)
     v2g_path: str = MISSING
     _target_: str = "gentropy.variant_to_gene.V2GStep"
diff --git a/src/gentropy/dataset/v2g.py b/src/gentropy/dataset/v2g.py
index 04bad2113..b14205b17 100644
--- a/src/gentropy/dataset/v2g.py
+++ b/src/gentropy/dataset/v2g.py
@@ -1,11 +1,10 @@
 """V2G dataset."""
+
 from __future__ import annotations
 
 from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
-import pyspark.sql.functions as f
-
 from gentropy.common.schemas import parse_spark_schema
 from gentropy.dataset.dataset import Dataset
 
@@ -42,10 +41,3 @@ def filter_by_genes(self: V2G, genes: GeneIndex) -> V2G:
         """
         self.df = self._df.join(genes.df.select("geneId"), on="geneId", how="inner")
         return self
-
-    def extract_distance_tss_minimum(self: V2G) -> None:
-        """Extract minimum distance to TSS."""
-        self.df = self._df.filter(f.col("distance")).withColumn(
-            "distanceTssMinimum",
-            f.expr("min(distTss) OVER (PARTITION BY studyLocusId)"),
-        )
diff --git a/src/gentropy/variant_to_gene.py b/src/gentropy/variant_to_gene.py
index cf21053d7..60c863f07 100644
--- a/src/gentropy/variant_to_gene.py
+++ b/src/gentropy/variant_to_gene.py
@@ -1,11 +1,9 @@
-"""Step to generate variant annotation dataset."""
+"""Step to generate variant to gene dataset."""
 
 from __future__ import annotations
 
 from functools import reduce
 
-from pyspark.sql import functions as f
-
 from gentropy.common.Liftover import LiftOverSpark
 from gentropy.common.session import Session
 from gentropy.dataset.gene_index import GeneIndex
@@ -20,18 +18,13 @@ class V2GStep:
 
     This step aims to generate a dataset that contains multiple pieces of evidence supporting the functional association of specific variants with genes. Some of the evidence types include:
     1. Chromatin interaction experiments, e.g. Promoter Capture Hi-C (PCHi-C).
-    2. In silico functional predictions, e.g. Variant Effect Predictor (VEP) from Ensembl.
-    3. Distance between the variant and each gene's canonical transcription start site (TSS).
 
     Attributes:
         session (Session): Session object.
         variant_index_path (str): Input variant index path.
         gene_index_path (str): Input gene index path.
-        vep_consequences_path (str): Input VEP consequences path.
        liftover_chain_file_path (str): Path to GRCh37 to GRCh38 chain file.
liftover_max_length_difference: Maximum length difference for liftover. - max_distance (int): Maximum distance to consider. - approved_biotypes (list[str]): List of approved biotypes. intervals (dict): Dictionary of interval sources. v2g_path (str): Output V2G path. """ @@ -41,12 +34,9 @@ def __init__( session: Session, variant_index_path: str, gene_index_path: str, - vep_consequences_path: str, liftover_chain_file_path: str, - approved_biotypes: list[str], interval_sources: dict[str, str], v2g_path: str, - max_distance: int = 500_000, liftover_max_length_difference: int = 100, ) -> None: """Run Variant-to-gene (V2G) step. @@ -55,21 +45,14 @@ def __init__( session (Session): Session object. variant_index_path (str): Input variant index path. gene_index_path (str): Input gene index path. - vep_consequences_path (str): Input VEP consequences path. liftover_chain_file_path (str): Path to GRCh37 to GRCh38 chain file. - approved_biotypes (list[str]): List of approved biotypes. interval_sources (dict[str, str]): Dictionary of interval sources. v2g_path (str): Output V2G path. - max_distance (int): Maximum distance to consider. liftover_max_length_difference (int): Maximum length difference for liftover. """ # Read gene_index = GeneIndex.from_parquet(session, gene_index_path) vi = VariantIndex.from_parquet(session, variant_index_path).persist() - # Reading VEP consequence to score table and cast the score to the right type: - vep_consequences = session.spark.read.csv( - vep_consequences_path, sep="\t", header=True - ).withColumn("score", f.col("score").cast("double")) # Transform lift = LiftOverSpark( @@ -77,10 +60,6 @@ def __init__( liftover_chain_file_path, liftover_max_length_difference, ) - gene_index_filtered = gene_index.filter_by_biotypes( - # Filter gene index by approved biotypes to define V2G gene universe - list(approved_biotypes) - ) intervals = Intervals( _df=reduce( @@ -95,19 +74,8 @@ def __init__( ), _schema=Intervals.get_schema(), ) - v2g_datasets = [ - vi.get_distance_to_tss(gene_index_filtered, max_distance), - vi.get_most_severe_transcript_consequence( - vep_consequences, gene_index_filtered - ), - vi.get_plof_v2g(gene_index_filtered), - intervals.v2g(vi), - ] v2g = V2G( - _df=reduce( - lambda x, y: x.unionByName(y, allowMissingColumns=True), - [dataset.df for dataset in v2g_datasets], - ).repartition("chromosome"), + _df=intervals.v2g(vi).df.repartition("chromosome"), _schema=V2G.get_schema(), ) From 36f88043dee72c24766189a2aee1a1665d02d6b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 17 Sep 2024 15:38:38 +0100 Subject: [PATCH 28/48] feat(feature_factory): add distance to footprint features --- src/gentropy/method/l2g/feature_factory.py | 218 ++++++++++++++---- .../method/test_l2g/test_feature_factory.py | 32 +++ 2 files changed, 204 insertions(+), 46 deletions(-) diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 9495ff905..402b597a8 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -12,10 +12,10 @@ from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus -from gentropy.dataset.v2g import V2G +from gentropy.dataset.variant_index import VariantIndex if TYPE_CHECKING: - from pyspark.sql import DataFrame + from pyspark.sql import Column, DataFrame class L2GFeatureInputLoader: @@ -408,83 +408,209 @@ def compute( ) +def 
_common_distance_feature_logic(
+    study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+    *,
+    variant_index: VariantIndex,
+    feature_name: str,
+    distance_type: str,
+    agg_expr: Column,
+) -> DataFrame:
+    """Computes a distance-based feature.
+
+    Args:
+        study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+        variant_index (VariantIndex): The dataset containing distance to gene information
+        feature_name (str): The name of the feature
+        distance_type (str): The type of distance to gene
+        agg_expr (Column): The expression that aggregates distances to define the feature
+
+    Returns:
+        DataFrame: Feature dataset
+    """
+    distances_dataset = variant_index.get_distance_to_gene(distance_type=distance_type)
+    return (
+        study_loci_to_annotate.df.withColumn("variantInLocus", f.explode_outer("locus"))
+        .select(
+            "studyLocusId",
+            f.col("variantInLocus.variantId").alias("variantInLocusId"),
+            f.col("variantInLocus.posteriorProbability").alias(
+                "variantInLocusPosteriorProbability"
+            ),
+        )
+        .join(
+            distances_dataset.withColumnRenamed(
+                "variantId", "variantInLocusId"
+            ).withColumnRenamed("targetId", "geneId"),
+            on="variantInLocusId",
+            how="inner",
+        )
+        .withColumn(
+            "weightedDistance",
+            f.col(distance_type) * f.col("variantInLocusPosteriorProbability"),
+        )
+        .groupBy("studyLocusId", "geneId")
+        .agg(agg_expr)
+        .alias(feature_name)
+    )
+
+
+class DistanceTssMeanFeature(L2GFeature):
+    """Average distance of all tagging variants to gene TSS."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceTssMean"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceTssMeanFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceTssMeanFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceTssMeanFeature: Feature dataset
+        """
+        agg_expr = f.mean("weightedDistance")
+        distance_type = "distanceFromTss"
+        return cls(
+            _df=convert_from_wide_to_long(
+                _common_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    agg_expr=agg_expr,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
 class DistanceTssMinimumFeature(L2GFeature):
     """Minimum distance of all tagging variants to gene TSS."""
 
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceTssMinimum"
+
     @classmethod
     def compute(
         cls: type[DistanceTssMinimumFeature],
         study_loci_to_annotate: StudyLocus | L2GGoldStandard,
-        feature_dependency: V2G,
-    ) -> L2GFeature:
+        feature_dependency: dict[str, Any],
+    ) -> DistanceTssMinimumFeature:
         """Computes the feature.
 
         Args:
             study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
-            feature_dependency (V2G): Dataset that contains the distance information
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
 
         Returns:
-            L2GFeature: Feature dataset
-
-        Raises:
-            NotImplementedError: Not implemented
+            DistanceTssMinimumFeature: Feature dataset
         """
-        raise NotImplementedError
-
+        agg_expr = f.min("weightedDistance")
+        distance_type = "distanceFromTss"
+        return cls(
+            _df=convert_from_wide_to_long(
+                _common_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    agg_expr=agg_expr,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
 
-class DistanceTssMeanFeature(L2GFeature):
-    """Average distance of all tagging variants to gene TSS.
-    NOTE: to be rewritten taking variant index as input
-    """
-    fill_na_value = 500_000
-    feature_dependency_type = V2G
+
+class DistanceFootprintMeanFeature(L2GFeature):
+    """Average distance of all tagging variants to the footprint of a gene."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceFootprintMean"
 
     @classmethod
     def compute(
-        cls: type[DistanceTssMeanFeature],
+        cls: type[DistanceFootprintMeanFeature],
         study_loci_to_annotate: StudyLocus | L2GGoldStandard,
-        feature_dependency: V2G,
-    ) -> DistanceTssMeanFeature:
+        feature_dependency: dict[str, Any],
+    ) -> DistanceFootprintMeanFeature:
         """Computes the feature.
 
         Args:
             study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
-            feature_dependency (V2G): Dataset that contains the distance information
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
 
         Returns:
-            DistanceTssMeanFeature: Feature dataset
+            DistanceFootprintMeanFeature: Feature dataset
         """
-        agg_expr = f.mean("weightedScore").alias("distanceTssMean")
-        # Everything but expresion is common logic
-        v2g = feature_dependency.df.filter(f.col("datasourceId") == "canonical_tss")
-        wide_df = (
-            study_loci_to_annotate.df.withColumn(
-                "variantInLocus", f.explode_outer("locus")
-            )
-            .select(
-                "studyLocusId",
-                f.col("variantInLocus.variantId").alias("variantInLocusId"),
-                f.col("variantInLocus.posteriorProbability").alias(
-                    "variantInLocusPosteriorProbability"
-                ),
-            )
-            .join(
-                v2g.selectExpr("variantId as variantInLocusId", "geneId", "score"),
-                on="variantInLocusId",
-                how="inner",
-            )
-            .withColumn(
-                "weightedScore",
-                f.col("score") * f.col("variantInLocusPosteriorProbability"),
-            )
-            .groupBy("studyLocusId", "geneId")
-            .agg(agg_expr)
+        agg_expr = f.mean("weightedDistance")
+        distance_type = "distanceFromFootprint"
+        return cls(
+            _df=convert_from_wide_to_long(
+                _common_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    agg_expr=agg_expr,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
         )
+
+
+class DistanceFootprintMinimumFeature(L2GFeature):
+    """Minimum distance of all tagging variants to the footprint of a gene."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceFootprintMinimum"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceFootprintMinimumFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceFootprintMinimumFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceFootprintMinimumFeature: Feature dataset
+        """
+        agg_expr = f.min("weightedDistance")
+        distance_type = "distanceFromFootprint"
         return cls(
             _df=convert_from_wide_to_long(
-                wide_df,
+                _common_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    agg_expr=agg_expr,
+                    **feature_dependency,
+                ),
                 id_vars=("studyLocusId", "geneId"),
                 var_name="featureName",
                 value_name="featureValue",
             ),
             _schema=cls.get_schema(),
         )
diff --git a/tests/gentropy/method/test_l2g/test_feature_factory.py b/tests/gentropy/method/test_l2g/test_feature_factory.py
index b2939a902..9f829d046 100644
--- a/tests/gentropy/method/test_l2g/test_feature_factory.py
+++ b/tests/gentropy/method/test_l2g/test_feature_factory.py
@@ -7,7 +7,12 @@
 import pytest
 
 from gentropy.dataset.l2g_feature import L2GFeature
+from gentropy.dataset.variant_index import VariantIndex
 from gentropy.method.l2g.feature_factory import (
+    DistanceFootprintMeanFeature,
+    DistanceFootprintMinimumFeature,
+    DistanceTssMeanFeature,
+    DistanceTssMinimumFeature,
     EQtlColocClppMaximumFeature,
     EQtlColocH4MaximumFeature,
     L2GFeatureInputLoader,
     PQtlColocClppMaximumFeature,
     PQtlColocH4MaximumFeature,
     SQtlColocClppMaximumFeature,
     SQtlColocH4MaximumFeature,
     TuQtlColocClppMaximumFeature,
     TuQtlColocH4MaximumFeature,
 )
@@ -63,6 +68,33 @@ def test_colocalisation_feature_type(
     assert isinstance(feature_dataset, L2GFeature)
 
 
+@pytest.mark.parametrize(
+    "feature_class",
+    [
+        DistanceTssMeanFeature,
+        DistanceTssMinimumFeature,
+        DistanceFootprintMeanFeature,
+        DistanceFootprintMinimumFeature,
+    ],
+)
+def test_distance_feature_type(
+    feature_class: Any,
+    mock_study_locus: StudyLocus,
+    mock_variant_index: VariantIndex,
+) -> None:
+    """Test that every distance feature type returns a set of L2GFeatures."""
+    loader = L2GFeatureInputLoader(
+        variant_index=mock_variant_index,
+    )
+    feature_dataset = feature_class.compute(
+        study_loci_to_annotate=mock_study_locus,
+        feature_dependency=loader.get_dependency_by_type(
+            feature_class.feature_dependency_type
+        ),
+    )
+    assert isinstance(feature_dataset, L2GFeature)
+
+
 # class TestColocalisationFactory:
 #     """Test the ColocalisationFactory methods."""

From 71042cbbe4579bf5db358df345fcb1aa4c898100 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3%B3pez?=
Date: Tue, 17 Sep 2024 15:44:30 +0100
Subject: [PATCH 29/48] test: refactor `test_feature_factory_return_type`

---
 .../method/test_l2g/test_feature_factory.py | 32 ++++---------------
 1 file changed, 6 insertions(+), 26 deletions(-)

diff --git a/tests/gentropy/method/test_l2g/test_feature_factory.py b/tests/gentropy/method/test_l2g/test_feature_factory.py
index 9f829d046..f6543b605 100644
--- a/tests/gentropy/method/test_l2g/test_feature_factory.py
+++ b/tests/gentropy/method/test_l2g/test_feature_factory.py
@@ -47,43 +47,23 @@
     PQtlColocClppMaximumFeature,
     SQtlColocClppMaximumFeature,
     TuQtlColocClppMaximumFeature,
- study_loci_to_annotate=mock_study_locus, - feature_dependency=loader.get_dependency_by_type( - feature_class.feature_dependency_type - ), - ) - assert isinstance(feature_dataset, L2GFeature) - - -@pytest.mark.parametrize( - "feature_class", - [ DistanceTssMeanFeature, DistanceTssMinimumFeature, DistanceFootprintMeanFeature, DistanceFootprintMinimumFeature, ], ) -def test_distance_feature_type( +def test_feature_factory_return_type( feature_class: Any, mock_study_locus: StudyLocus, + mock_colocalisation: Colocalisation, + mock_study_index: StudyIndex, mock_variant_index: VariantIndex, ) -> None: - """Test that every distance feature type returns a set of L2GFeatures.""" + """Test that every feature factory returns a L2GFeature dataset.""" loader = L2GFeatureInputLoader( + colocalisation=mock_colocalisation, + study_index=mock_study_index, variant_index=mock_variant_index, ) feature_dataset = feature_class.compute( From 99477ab77073208a5a058a5ab0eb4d951d8b3cd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 17 Sep 2024 17:45:27 +0100 Subject: [PATCH 30/48] feat(feature_factory): add all distance neighbourhood features --- src/gentropy/method/l2g/feature_factory.py | 218 +++++++++++++++++- .../method/test_l2g/test_feature_factory.py | 91 ++------ 2 files changed, 234 insertions(+), 75 deletions(-) diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 402b597a8..6310f4bb3 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Iterator, Mapping import pyspark.sql.functions as f +from pyspark.sql import Window from gentropy.common.spark_helpers import convert_from_wide_to_long from gentropy.dataset.colocalisation import Colocalisation @@ -450,8 +451,47 @@ def _common_distance_feature_logic( f.col(distance_type) * f.col("variantInLocusPosteriorProbability"), ) .groupBy("studyLocusId", "geneId") - .agg(agg_expr) - .alias(feature_name) + .agg(agg_expr.alias(feature_name)) + ) + + +def _common_neighbourhood_distance_feature_logic( + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + *, + variant_index: VariantIndex, + feature_name: str, + distance_type: str, + agg_expr: Column, +) -> DataFrame: + """Calculate the neighbourhood distance feature. 
+
+    Args:
+        study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+        variant_index (VariantIndex): The dataset containing distance to gene information
+        feature_name (str): The name of the feature
+        distance_type (str): The type of distance to gene
+        agg_expr (Column): The expression that aggregates distances to define the feature
+
+    Returns:
+        DataFrame: Feature dataset
+    """
+    local_feature_name = feature_name.replace("Neighbourhood", "")
+    # First compute the aggregated distance to each gene
+    local_distance = _common_distance_feature_logic(
+        study_loci_to_annotate,
+        feature_name=local_feature_name,
+        distance_type=distance_type,
+        agg_expr=agg_expr,
+        variant_index=variant_index,
+    )
+    return (
+        # Then compute minimum distance in the vicinity (feature will be the same for any gene associated with a studyLocus)
+        local_distance.withColumn(
+            "local_minimum",
+            f.min(local_feature_name).over(Window.partitionBy("studyLocusId")),
+        )
+        .withColumn(feature_name, f.col("local_minimum") - f.col(local_feature_name))
+        .drop("local_minimum")
+    )
+
+
+class DistanceTssMeanNeighbourhoodFeature(L2GFeature):
+    """Minimum mean distance to TSS for all genes in the vicinity of a studyLocus."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceTssMeanNeighbourhood"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceTssMeanNeighbourhoodFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceTssMeanNeighbourhoodFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceTssMeanNeighbourhoodFeature: Feature dataset
+        """
+        agg_expr = f.mean("weightedDistance")
+        distance_type = "distanceFromTss"
+        return cls(
+            _df=convert_from_wide_to_long(
+                _common_neighbourhood_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    agg_expr=agg_expr,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
 class DistanceTssMinimumFeature(L2GFeature):
     """Minimum distance of all tagging variants to gene TSS."""
 
@@ -537,6 +618,47 @@ def compute(
         )
 
 
+class DistanceTssMinimumNeighbourhoodFeature(L2GFeature):
+    """Minimum minimum distance to TSS for all genes in the vicinity of a studyLocus."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceTssMinimumNeighbourhood"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceTssMinimumNeighbourhoodFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceTssMinimumNeighbourhoodFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceTssMinimumNeighbourhoodFeature: Feature dataset
+        """
+        agg_expr = f.min("weightedDistance")
+        distance_type = "distanceFromTss"
+        return cls(
+            _df=convert_from_wide_to_long(
+                _common_neighbourhood_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    agg_expr=agg_expr,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
 class DistanceFootprintMeanFeature(L2GFeature):
     """Average distance of all tagging variants to the footprint of a gene."""
 
@@ -578,6 +700,47 @@ def compute(
     )
 
 
+class DistanceFootprintMeanNeighbourhoodFeature(L2GFeature):
+    """Mean distance to the gene footprint, relative to the minimum mean distance among all genes in the vicinity of a studyLocus."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceFootprintMeanNeighbourhood"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceFootprintMeanNeighbourhoodFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceFootprintMeanNeighbourhoodFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceFootprintMeanNeighbourhoodFeature: Feature dataset
+        """
+        agg_expr = f.mean("weightedDistance")
+        distance_type = "distanceFromFootprint"
+        return cls(
+            _df=convert_from_wide_to_long(
+                _common_neighbourhood_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    agg_expr=agg_expr,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
 class DistanceFootprintMinimumFeature(L2GFeature):
     """Minimum distance of all tagging variants to the footprint of a gene."""
 
@@ -619,12 +782,59 @@ def compute(
     )
 
 
+class DistanceFootprintMinimumNeighbourhoodFeature(L2GFeature):
+    """Minimum distance to the gene footprint, relative to the smallest minimum distance among all genes in the vicinity of a studyLocus."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceFootprintMinimumNeighbourhood"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceFootprintMinimumNeighbourhoodFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceFootprintMinimumNeighbourhoodFeature:
+        """Computes the feature.
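+
+        The local aggregation is the minimum weighted distance to the gene
+        footprint per gene, which is then compared against the locus-wide minimum.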
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the distance information + + Returns: + DistanceFootprintMinimumNeighbourhoodFeature: Feature dataset + """ + agg_expr = f.min("weightedDistance") + distance_type = "distanceFromFootprint" + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_distance_feature_logic( + study_loci_to_annotate, + feature_name=cls.feature_name, + distance_type=distance_type, + agg_expr=agg_expr, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + class FeatureFactory: """Factory class for creating features.""" feature_mapper: Mapping[str, type[L2GFeature]] = { - # "distanceTssMinimum": DistanceTssMinimumFeature, - # "distanceTssMean": DistanceTssMeanFeature, + "distanceTssMinimum": DistanceTssMinimumFeature, + "distanceTssMean": DistanceTssMeanFeature, + "distanceTssMeanNeighbourhood": DistanceTssMeanNeighbourhoodFeature, + "distanceTssMinimumNeighbourhood": DistanceTssMinimumNeighbourhoodFeature, + "distanceFootprintMinimum": DistanceFootprintMinimumFeature, + "distanceFootprintMean": DistanceFootprintMeanFeature, + "distanceFootprintMinimumNeighbourhood": DistanceFootprintMinimumNeighbourhoodFeature, + "distanceFootprintMeanNeighbourhood": DistanceFootprintMeanNeighbourhoodFeature, "eQtlColocClppMaximum": EQtlColocClppMaximumFeature, "pQtlColocClppMaximum": PQtlColocClppMaximumFeature, "sQtlColocClppMaximum": SQtlColocClppMaximumFeature, diff --git a/tests/gentropy/method/test_l2g/test_feature_factory.py b/tests/gentropy/method/test_l2g/test_feature_factory.py index f6543b605..53eeb4683 100644 --- a/tests/gentropy/method/test_l2g/test_feature_factory.py +++ b/tests/gentropy/method/test_l2g/test_feature_factory.py @@ -9,19 +9,11 @@ from gentropy.dataset.l2g_feature import L2GFeature from gentropy.dataset.variant_index import VariantIndex from gentropy.method.l2g.feature_factory import ( - DistanceFootprintMeanFeature, - DistanceFootprintMinimumFeature, - DistanceTssMeanFeature, - DistanceTssMinimumFeature, - EQtlColocClppMaximumFeature, - EQtlColocH4MaximumFeature, + DistanceFootprintMeanNeighbourhoodFeature, + DistanceFootprintMinimumNeighbourhoodFeature, + DistanceTssMeanNeighbourhoodFeature, + DistanceTssMinimumNeighbourhoodFeature, L2GFeatureInputLoader, - PQtlColocClppMaximumFeature, - PQtlColocH4MaximumFeature, - SQtlColocClppMaximumFeature, - SQtlColocH4MaximumFeature, - TuQtlColocClppMaximumFeature, - TuQtlColocH4MaximumFeature, ) if TYPE_CHECKING: @@ -30,27 +22,25 @@ from gentropy.dataset.study_locus import StudyLocus -# @pytest.fixture(scope="module") -# def model() -> LocusToGeneModel: -# """Creates an instance of the LocusToGene class.""" -# return LocusToGeneModel(model=RandomForestClassifier()) - - @pytest.mark.parametrize( "feature_class", [ - EQtlColocH4MaximumFeature, - PQtlColocH4MaximumFeature, - SQtlColocH4MaximumFeature, - TuQtlColocH4MaximumFeature, - EQtlColocClppMaximumFeature, - PQtlColocClppMaximumFeature, - SQtlColocClppMaximumFeature, - TuQtlColocClppMaximumFeature, - DistanceTssMeanFeature, - DistanceTssMinimumFeature, - DistanceFootprintMeanFeature, - DistanceFootprintMinimumFeature, + # EQtlColocH4MaximumFeature, + # PQtlColocH4MaximumFeature, + # SQtlColocH4MaximumFeature, + # TuQtlColocH4MaximumFeature, + # EQtlColocClppMaximumFeature, + # 
PQtlColocClppMaximumFeature, + # SQtlColocClppMaximumFeature, + # TuQtlColocClppMaximumFeature, + # DistanceTssMeanFeature, + # DistanceTssMinimumFeature, + # DistanceFootprintMeanFeature, + # DistanceFootprintMinimumFeature, + DistanceTssMeanNeighbourhoodFeature, + DistanceTssMinimumNeighbourhoodFeature, + DistanceFootprintMeanNeighbourhoodFeature, + DistanceFootprintMinimumNeighbourhoodFeature, ], ) def test_feature_factory_return_type( @@ -77,23 +67,6 @@ def test_feature_factory_return_type( # class TestColocalisationFactory: # """Test the ColocalisationFactory methods.""" - -# def test_get_max_coloc_per_credible_set( -# self: TestColocalisationFactory, -# mock_study_locus: StudyLocus, -# mock_study_index: StudyIndex, -# mock_colocalisation: Colocalisation, -# ) -> None: -# """Test the function that extracts the maximum log likelihood ratio for each pair of overlapping study-locus returns the right data type.""" -# coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set( -# mock_colocalisation, -# mock_study_locus, -# mock_study_index, -# ) -# assert isinstance( -# coloc_features, L2GFeature -# ), "Unexpected type returned from _get_max_coloc_per_credible_set" - # def test_get_max_coloc_per_credible_set_semantic( # self: TestColocalisationFactory, # spark: SparkSession, @@ -179,27 +152,3 @@ def test_feature_factory_return_type( # studies, # ) # assert coloc_features.df.collect() == expected_coloc_features_df.collect() - - -# class TestStudyLocusFactory: -# """Test the StudyLocusFactory methods.""" - -# def test_get_tss_distance_features( -# self: TestStudyLocusFactory, mock_study_locus: StudyLocus, mock_v2g: V2G -# ) -> None: -# """Test the function that extracts the distance to the TSS.""" -# tss_distance = StudyLocusFactory._get_tss_distance_features( -# mock_study_locus, mock_v2g -# ) -# assert isinstance( -# tss_distance, L2GFeature -# ), "Unexpected model type returned from _get_tss_distance_features" - -# def test_get_vep_features( -# self: TestStudyLocusFactory, mock_study_locus: StudyLocus, mock_v2g: V2G -# ) -> None: -# """Test the function that extracts the VEP features.""" -# vep_features = StudyLocusFactory._get_vep_features(mock_study_locus, mock_v2g) -# assert isinstance( -# vep_features, L2GFeature -# ), "Unexpected model type returned from _get_vep_features" From 73e795c48e01ef84c8adcd0a9bcdce7c9fb22d9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 17 Sep 2024 17:51:48 +0100 Subject: [PATCH 31/48] chore: delete v2g --- docs/python_api/datasets/variant_to_gene.md | 9 -- docs/python_api/steps/variant_to_gene_step.md | 5 -- src/gentropy/assets/schemas/v2g.json | 77 ---------------- src/gentropy/config.py | 14 --- src/gentropy/dataset/v2g.py | 43 --------- src/gentropy/variant_to_gene.py | 87 ------------------- tests/gentropy/dataset/test_v2g.py | 23 ----- .../method/test_l2g/test_feature_factory.py | 36 +++++--- 8 files changed, 24 insertions(+), 270 deletions(-) delete mode 100644 docs/python_api/datasets/variant_to_gene.md delete mode 100644 docs/python_api/steps/variant_to_gene_step.md delete mode 100644 src/gentropy/assets/schemas/v2g.json delete mode 100644 src/gentropy/dataset/v2g.py delete mode 100644 src/gentropy/variant_to_gene.py delete mode 100644 tests/gentropy/dataset/test_v2g.py diff --git a/docs/python_api/datasets/variant_to_gene.md b/docs/python_api/datasets/variant_to_gene.md deleted file mode 100644 index 2af67df92..000000000 --- a/docs/python_api/datasets/variant_to_gene.md +++ /dev/null @@ -1,9 +0,0 @@ ---- 
-title: Variant-to-gene ---- - -::: gentropy.dataset.v2g.V2G - -## Schema - ---8<-- "assets/schemas/v2g.md" diff --git a/docs/python_api/steps/variant_to_gene_step.md b/docs/python_api/steps/variant_to_gene_step.md deleted file mode 100644 index db1c1fd20..000000000 --- a/docs/python_api/steps/variant_to_gene_step.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -title: variant_to_gene ---- - -::: gentropy.variant_to_gene.V2GStep diff --git a/src/gentropy/assets/schemas/v2g.json b/src/gentropy/assets/schemas/v2g.json deleted file mode 100644 index afbe401dd..000000000 --- a/src/gentropy/assets/schemas/v2g.json +++ /dev/null @@ -1,77 +0,0 @@ -{ - "type": "struct", - "fields": [ - { - "name": "geneId", - "type": "string", - "nullable": false, - "metadata": {} - }, - { - "name": "variantId", - "type": "string", - "nullable": false, - "metadata": {} - }, - { - "name": "distance", - "type": "long", - "nullable": true, - "metadata": {} - }, - { - "name": "chromosome", - "type": "string", - "nullable": false, - "metadata": {} - }, - { - "name": "datatypeId", - "type": "string", - "nullable": false, - "metadata": {} - }, - { - "name": "datasourceId", - "type": "string", - "nullable": false, - "metadata": {} - }, - { - "name": "score", - "type": "double", - "nullable": true, - "metadata": {} - }, - { - "name": "resourceScore", - "type": "double", - "nullable": true, - "metadata": {} - }, - { - "name": "pmid", - "type": "string", - "nullable": true, - "metadata": {} - }, - { - "name": "biofeature", - "type": "string", - "nullable": true, - "metadata": {} - }, - { - "name": "variantFunctionalConsequenceId", - "type": "string", - "nullable": true, - "metadata": {} - }, - { - "name": "isHighQualityPlof", - "type": "boolean", - "nullable": true, - "metadata": {} - } - ] -} diff --git a/src/gentropy/config.py b/src/gentropy/config.py index 620e7a12b..e6cb49ab8 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -327,19 +327,6 @@ class ConvertToVcfStepConfig(StepConfig): _target_: str = "gentropy.variant_index.ConvertToVcfStep" -@dataclass -class VariantToGeneConfig(StepConfig): - """V2G step configuration.""" - - variant_index_path: str = MISSING - gene_index_path: str = MISSING - liftover_chain_file_path: str = MISSING - liftover_max_length_difference: int = 100 - interval_sources: Dict[str, str] = field(default_factory=dict) - v2g_path: str = MISSING - _target_: str = "gentropy.variant_to_gene.V2GStep" - - @dataclass class LocusBreakerClumpingConfig(StepConfig): """Locus breaker clumping step configuration.""" @@ -514,7 +501,6 @@ def register_config() -> None: cs.store(group="step", name="ukb_ppp_eur_sumstat_preprocess", node=UkbPppEurConfig) cs.store(group="step", name="variant_index", node=VariantIndexConfig) cs.store(group="step", name="variant_to_vcf", node=ConvertToVcfStepConfig) - cs.store(group="step", name="variant_to_gene", node=VariantToGeneConfig) cs.store( group="step", name="window_based_clumping", node=WindowBasedClumpingStepConfig ) diff --git a/src/gentropy/dataset/v2g.py b/src/gentropy/dataset/v2g.py deleted file mode 100644 index b14205b17..000000000 --- a/src/gentropy/dataset/v2g.py +++ /dev/null @@ -1,43 +0,0 @@ -"""V2G dataset.""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING - -from gentropy.common.schemas import parse_spark_schema -from gentropy.dataset.dataset import Dataset - -if TYPE_CHECKING: - from pyspark.sql.types import StructType - - from gentropy.dataset.gene_index import GeneIndex - - -@dataclass 
-class V2G(Dataset): - """Variant-to-gene (V2G) evidence dataset. - - A variant-to-gene (V2G) evidence is understood as any piece of evidence that supports the association of a variant with a likely causal gene. The evidence can sometimes be context-specific and refer to specific `biofeatures` (e.g. cell types) - """ - - @classmethod - def get_schema(cls: type[V2G]) -> StructType: - """Provides the schema for the V2G dataset. - - Returns: - StructType: Schema for the V2G dataset - """ - return parse_spark_schema("v2g.json") - - def filter_by_genes(self: V2G, genes: GeneIndex) -> V2G: - """Filter V2G dataset by genes. - - Args: - genes (GeneIndex): Gene index dataset to filter by - - Returns: - V2G: V2G dataset filtered by genes - """ - self.df = self._df.join(genes.df.select("geneId"), on="geneId", how="inner") - return self diff --git a/src/gentropy/variant_to_gene.py b/src/gentropy/variant_to_gene.py deleted file mode 100644 index 60c863f07..000000000 --- a/src/gentropy/variant_to_gene.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Step to generate variant to gene dataset.""" - -from __future__ import annotations - -from functools import reduce - -from gentropy.common.Liftover import LiftOverSpark -from gentropy.common.session import Session -from gentropy.dataset.gene_index import GeneIndex -from gentropy.dataset.intervals import Intervals -from gentropy.dataset.v2g import V2G -from gentropy.dataset.variant_index import VariantIndex - - -class V2GStep: - """Variant-to-gene (V2G) step. - - This step aims to generate a dataset that contains multiple pieces of evidence supporting the functional association of specific variants with genes. Some of the evidence types include: - - 1. Chromatin interaction experiments, e.g. Promoter Capture Hi-C (PCHi-C). - - Attributes: - session (Session): Session object. - variant_index_path (str): Input variant index path. - gene_index_path (str): Input gene index path. - liftover_chain_file_path (str): Path to GRCh37 to GRCh38 chain file. - liftover_max_length_difference: Maximum length difference for liftover. - intervals (dict): Dictionary of interval sources. - v2g_path (str): Output V2G path. - """ - - def __init__( - self, - session: Session, - variant_index_path: str, - gene_index_path: str, - liftover_chain_file_path: str, - interval_sources: dict[str, str], - v2g_path: str, - liftover_max_length_difference: int = 100, - ) -> None: - """Run Variant-to-gene (V2G) step. - - Args: - session (Session): Session object. - variant_index_path (str): Input variant index path. - gene_index_path (str): Input gene index path. - liftover_chain_file_path (str): Path to GRCh37 to GRCh38 chain file. - interval_sources (dict[str, str]): Dictionary of interval sources. - v2g_path (str): Output V2G path. - liftover_max_length_difference (int): Maximum length difference for liftover. 
- """ - # Read - gene_index = GeneIndex.from_parquet(session, gene_index_path) - vi = VariantIndex.from_parquet(session, variant_index_path).persist() - - # Transform - lift = LiftOverSpark( - # lift over variants to hg38 - liftover_chain_file_path, - liftover_max_length_difference, - ) - - intervals = Intervals( - _df=reduce( - lambda x, y: x.unionByName(y, allowMissingColumns=True), - # create interval instances by parsing each source - [ - Intervals.from_source( - session.spark, source_name, source_path, gene_index, lift - ).df - for source_name, source_path in interval_sources.items() - ], - ), - _schema=Intervals.get_schema(), - ) - v2g = V2G( - _df=intervals.v2g(vi).df.repartition("chromosome"), - _schema=V2G.get_schema(), - ) - - # Load - ( - v2g.df.write.partitionBy("chromosome") - .mode(session.write_mode) - .parquet(v2g_path) - ) diff --git a/tests/gentropy/dataset/test_v2g.py b/tests/gentropy/dataset/test_v2g.py deleted file mode 100644 index 24a917508..000000000 --- a/tests/gentropy/dataset/test_v2g.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Tests V2G dataset.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from gentropy.dataset.v2g import V2G - -if TYPE_CHECKING: - from gentropy.dataset.gene_index import GeneIndex - - -def test_v2g_creation(mock_v2g: V2G) -> None: - """Test v2g creation with mock data.""" - assert isinstance(mock_v2g, V2G) - - -def test_v2g_filter_by_genes(mock_v2g: V2G, mock_gene_index: GeneIndex) -> None: - """Test v2g filter by genes.""" - assert isinstance( - mock_v2g.filter_by_genes(mock_gene_index), - V2G, - ) diff --git a/tests/gentropy/method/test_l2g/test_feature_factory.py b/tests/gentropy/method/test_l2g/test_feature_factory.py index 53eeb4683..359c04b78 100644 --- a/tests/gentropy/method/test_l2g/test_feature_factory.py +++ b/tests/gentropy/method/test_l2g/test_feature_factory.py @@ -9,11 +9,23 @@ from gentropy.dataset.l2g_feature import L2GFeature from gentropy.dataset.variant_index import VariantIndex from gentropy.method.l2g.feature_factory import ( + DistanceFootprintMeanFeature, DistanceFootprintMeanNeighbourhoodFeature, + DistanceFootprintMinimumFeature, DistanceFootprintMinimumNeighbourhoodFeature, + DistanceTssMeanFeature, DistanceTssMeanNeighbourhoodFeature, + DistanceTssMinimumFeature, DistanceTssMinimumNeighbourhoodFeature, + EQtlColocClppMaximumFeature, + EQtlColocH4MaximumFeature, L2GFeatureInputLoader, + PQtlColocClppMaximumFeature, + PQtlColocH4MaximumFeature, + SQtlColocClppMaximumFeature, + SQtlColocH4MaximumFeature, + TuQtlColocClppMaximumFeature, + TuQtlColocH4MaximumFeature, ) if TYPE_CHECKING: @@ -25,18 +37,18 @@ @pytest.mark.parametrize( "feature_class", [ - # EQtlColocH4MaximumFeature, - # PQtlColocH4MaximumFeature, - # SQtlColocH4MaximumFeature, - # TuQtlColocH4MaximumFeature, - # EQtlColocClppMaximumFeature, - # PQtlColocClppMaximumFeature, - # SQtlColocClppMaximumFeature, - # TuQtlColocClppMaximumFeature, - # DistanceTssMeanFeature, - # DistanceTssMinimumFeature, - # DistanceFootprintMeanFeature, - # DistanceFootprintMinimumFeature, + EQtlColocH4MaximumFeature, + PQtlColocH4MaximumFeature, + SQtlColocH4MaximumFeature, + TuQtlColocH4MaximumFeature, + EQtlColocClppMaximumFeature, + PQtlColocClppMaximumFeature, + SQtlColocClppMaximumFeature, + TuQtlColocClppMaximumFeature, + DistanceTssMeanFeature, + DistanceTssMinimumFeature, + DistanceFootprintMeanFeature, + DistanceFootprintMinimumFeature, DistanceTssMeanNeighbourhoodFeature, DistanceTssMinimumNeighbourhoodFeature, 
DistanceFootprintMeanNeighbourhoodFeature, From 65a6771e63f87b80d2fc7793020619ed0d4e39fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 18 Sep 2024 15:24:26 +0100 Subject: [PATCH 32/48] feat(feature_factory): add all colocalisation neighbourhood features --- src/gentropy/method/l2g/feature_factory.py | 573 +++++++++++++++--- .../method/test_l2g/test_feature_factory.py | 20 +- 2 files changed, 514 insertions(+), 79 deletions(-) diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 6310f4bb3..d07219866 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -95,20 +95,62 @@ def _common_colocalisation_feature_logic( Returns: DataFrame: Feature annotation in long format with the columns: studyLocusId, geneId, featureName, featureValue """ - return convert_from_wide_to_long( - colocalisation.extract_maximum_coloc_probability_per_region_and_gene( - study_loci_to_annotate, - study_index, - filter_by_colocalisation_method=colocalisation_method, - filter_by_qtl=qtl_type, - ).selectExpr( - "studyLocusId", - "geneId", - f"{colocalisation_metric} as {feature_name}", - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", + return colocalisation.extract_maximum_coloc_probability_per_region_and_gene( + study_loci_to_annotate, + study_index, + filter_by_colocalisation_method=colocalisation_method, + filter_by_qtl=qtl_type, + ).selectExpr( + "studyLocusId", + "geneId", + f"{colocalisation_metric} as {feature_name}", + ) + + +def _common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + colocalisation_method: str, + colocalisation_metric: str, + feature_name: str, + qtl_type: str, + *, + colocalisation: Colocalisation, + study_index: StudyIndex, +) -> DataFrame: + """Wrapper to call the logic that creates a type of colocalisation features. 
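+
+    The feature is relative to the locus: it is the maximum colocalisation score
+    across all genes in the vicinity minus the gene's own maximum score, so the
+    top-scoring gene gets 0 and every other gene gets a positive value.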
+
+    Args:
+        study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+        colocalisation_method (str): The colocalisation method to filter the data by
+        colocalisation_metric (str): The colocalisation metric to use
+        feature_name (str): The name of the feature to create
+        qtl_type (str): The type of QTL to filter the data by
+        colocalisation (Colocalisation): Dataset with the colocalisation results
+        study_index (StudyIndex): Study index to fetch study type and gene
+
+    Returns:
+        DataFrame: Feature annotation in wide format, with one row per studyLocusId and geneId pair and the feature value in its own column
+    """
+    # First compute the maximum colocalisation score for each (studyLocusId, geneId) pair
+    local_feature_name = feature_name.replace("Neighbourhood", "")
+    local_max = colocalisation.extract_maximum_coloc_probability_per_region_and_gene(
+        study_loci_to_annotate,
+        study_index,
+        filter_by_colocalisation_method=colocalisation_method,
+        filter_by_qtl=qtl_type,
+    ).selectExpr(
+        "studyLocusId",
+        "geneId",
+        f"{colocalisation_metric} as {local_feature_name}",
+    )
+    return (
+        # Then compute the maximum of those scores across the locus (this maximum is the same for every gene associated with a studyLocus)
+        local_max.withColumn(
+            "regional_maximum",
+            f.max(local_feature_name).over(Window.partitionBy("studyLocusId")),
+        )
+        .withColumn(feature_name, f.col("regional_maximum") - f.col(local_feature_name))
+        .drop("regional_maximum")
     )
 
 
@@ -138,13 +180,61 @@ def compute(
         qtl_type = "eqtl"
 
         return cls(
-            _df=_common_colocalisation_feature_logic(
-                study_loci_to_annotate,
-                colocalisation_method,
-                colocalisation_metric,
-                cls.feature_name,
-                qtl_type,
-                **feature_dependency,
+            _df=convert_from_wide_to_long(
+                _common_colocalisation_feature_logic(
+                    study_loci_to_annotate,
+                    colocalisation_method,
+                    colocalisation_metric,
+                    cls.feature_name,
+                    qtl_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
             ),
             _schema=cls.get_schema(),
         )
+
+
+class EQtlColocClppMaximumNeighbourhoodFeature(L2GFeature):
+    """Max CLPP for each (study, locus) aggregating over all eQTLs."""
+
+    feature_dependency_type = [Colocalisation, StudyIndex]
+    feature_name = "eQtlColocClppMaximumNeighbourhood"
+
+    @classmethod
+    def compute(
+        cls: type[EQtlColocClppMaximumNeighbourhoodFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> EQtlColocClppMaximumNeighbourhoodFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments.
+ + Returns: + EQtlColocClppMaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "eqtl" + + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", ), _schema=cls.get_schema(), ) @@ -175,13 +265,60 @@ def compute( colocalisation_metric = "clpp" qtl_type = "pqtl" return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, + _df=convert_from_wide_to_long( + _common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class PQtlColocClppMaximumNeighbourhoodFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex] + feature_name = "pQtlColocClppMaximumNeighbourhood" + + @classmethod + def compute( + cls: type[PQtlColocClppMaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> PQtlColocClppMaximumNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + PQtlColocClppMaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "pqtl" + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", ), _schema=cls.get_schema(), ) @@ -212,13 +349,60 @@ def compute( colocalisation_metric = "clpp" qtl_type = "sqtl" return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, + _df=convert_from_wide_to_long( + _common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class SQtlColocClppMaximumNeighbourhoodFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all sQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex] + feature_name = "sQtlColocClppMaximumNeighbourhood" + + @classmethod + def compute( + cls: type[SQtlColocClppMaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> SQtlColocClppMaximumNeighbourhoodFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + SQtlColocClppMaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "sqtl" + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", ), _schema=cls.get_schema(), ) @@ -249,20 +433,67 @@ def compute( colocalisation_metric = "clpp" qtl_type = "tuqtl" return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, + _df=convert_from_wide_to_long( + _common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class TuQtlColocClppMaximumNeighbourhoodFeature(L2GFeature): + """Max CLPP for each (study, locus) aggregating over all tuQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex] + feature_name = "tuQtlColocClppMaximumNeighbourhood" + + @classmethod + def compute( + cls: type[TuQtlColocClppMaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> TuQtlColocClppMaximumNeighbourhoodFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + TuQtlColocClppMaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "tuqtl" + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", ), _schema=cls.get_schema(), ) class EQtlColocH4MaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all eQTLs.""" + """Max H4 for each (study, locus, gene) aggregating over all eQTLs.""" feature_dependency_type = [Colocalisation, StudyIndex] feature_name = "eQtlColocH4Maximum" @@ -286,13 +517,60 @@ def compute( colocalisation_metric = "h4" qtl_type = "eqtl" return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, + _df=convert_from_wide_to_long( + _common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class EQtlColocH4MaximumNeighbourhoodFeature(L2GFeature): + """Max H4 for each (study, locus) aggregating over all eQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex] + feature_name = "eQtlColocH4MaximumNeighbourhood" + + @classmethod + def compute( + cls: type[EQtlColocH4MaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> EQtlColocH4MaximumNeighbourhoodFeature: + """Computes the feature. 
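+
+        The local value is the maximum COLOC H4 posterior probability across eQTL
+        studies, which is then compared against the locus-wide maximum.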
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + EQtlColocH4MaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "Coloc" + colocalisation_metric = "h4" + qtl_type = "eqtl" + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", ), _schema=cls.get_schema(), ) @@ -323,20 +601,67 @@ def compute( colocalisation_metric = "h4" qtl_type = "pqtl" return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, + _df=convert_from_wide_to_long( + _common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class PQtlColocH4MaximumNeighbourhoodFeature(L2GFeature): + """Max H4 for each (study, locus) aggregating over all pQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex] + feature_name = "pQtlColocH4MaximumNeighbourhood" + + @classmethod + def compute( + cls: type[PQtlColocH4MaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> PQtlColocH4MaximumNeighbourhoodFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + PQtlColocH4MaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "Coloc" + colocalisation_metric = "h4" + qtl_type = "pqtl" + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", ), _schema=cls.get_schema(), ) class SQtlColocH4MaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all sQTLs.""" + """Max H4 for each (study, locus, gene) aggregating over all sQTLs.""" feature_dependency_type = [Colocalisation, StudyIndex] feature_name = "sQtlColocH4Maximum" @@ -360,13 +685,60 @@ def compute( colocalisation_metric = "h4" qtl_type = "sqtl" return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, + _df=convert_from_wide_to_long( + _common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class SQtlColocH4MaximumNeighbourhoodFeature(L2GFeature): + """Max H4 for each (study, locus) aggregating over all sQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex] + feature_name = "sQtlColocH4MaximumNeighbourhood" + + @classmethod + def compute( + cls: type[SQtlColocH4MaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> SQtlColocH4MaximumNeighbourhoodFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + SQtlColocH4MaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "Coloc" + colocalisation_metric = "h4" + qtl_type = "sqtl" + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", ), _schema=cls.get_schema(), ) @@ -397,13 +769,60 @@ def compute( colocalisation_metric = "h4" qtl_type = "tuqtl" return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, + _df=convert_from_wide_to_long( + _common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class TuQtlColocH4MaximumNeighbourhoodFeature(L2GFeature): + """Max H4 for each (study, locus) aggregating over all tuQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex] + feature_name = "tuQtlColocH4MaximumNeighbourhood" + + @classmethod + def compute( + cls: type[TuQtlColocH4MaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> TuQtlColocH4MaximumNeighbourhoodFeature: + """Computes the feature. 
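+
+        As with the other neighbourhood features, the maximum H4 across tuQTL
+        studies is reported relative to the locus-wide maximum.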
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset with the colocalisation results
+
+        Returns:
+            TuQtlColocH4MaximumNeighbourhoodFeature: Feature dataset
+        """
+        colocalisation_method = "Coloc"
+        colocalisation_metric = "h4"
+        qtl_type = "tuqtl"
+        return cls(
+            _df=convert_from_wide_to_long(
+                _common_neighbourhood_colocalisation_feature_logic(
+                    study_loci_to_annotate,
+                    colocalisation_method,
+                    colocalisation_metric,
+                    cls.feature_name,
+                    qtl_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
 
 
@@ -477,7 +896,7 @@ def _common_neighbourhood_distance_feature_logic(
     """
     local_feature_name = feature_name.replace("Neighbourhood", "")
     # First compute the aggregated distance to each gene in the locus
-    local_distance = _common_distance_feature_logic(
+    local_min = _common_distance_feature_logic(
         study_loci_to_annotate,
         feature_name=local_feature_name,
         distance_type=distance_type,
@@ -486,12 +905,12 @@ def _common_neighbourhood_distance_feature_logic(
     )
     return (
         # Then compute the minimum of those distances across the locus (this minimum is the same for every gene associated with a studyLocus)
-        local_distance.withColumn(
-            "local_minimum",
+        local_min.withColumn(
+            "regional_minimum",
             f.min(local_feature_name).over(Window.partitionBy("studyLocusId")),
         )
-        .withColumn(feature_name, f.col("local_minimum") - f.col(local_feature_name))
-        .drop("local_minimum")
+        .withColumn(feature_name, f.col("regional_minimum") - f.col(local_feature_name))
+        .drop("regional_minimum")
     )
diff --git a/tests/gentropy/method/test_l2g/test_feature_factory.py b/tests/gentropy/method/test_l2g/test_feature_factory.py
index 359c04b78..30c7c39ea 100644
--- a/tests/gentropy/method/test_l2g/test_feature_factory.py
+++ b/tests/gentropy/method/test_l2g/test_feature_factory.py
@@ -1,4 +1,4 @@
-"""Test locus-to-gene model training."""
+"""Test locus-to-gene feature generation."""
 
 from __future__ import annotations
 
@@ -7,7 +7,6 @@
 import pytest
 
 from gentropy.dataset.l2g_feature import L2GFeature
-from gentropy.dataset.variant_index import VariantIndex
 from gentropy.method.l2g.feature_factory import (
     DistanceFootprintMeanFeature,
     DistanceFootprintMeanNeighbourhoodFeature,
     DistanceFootprintMinimumFeature,
     DistanceFootprintMinimumNeighbourhoodFeature,
     DistanceTssMeanFeature,
     DistanceTssMeanNeighbourhoodFeature,
     DistanceTssMinimumFeature,
     DistanceTssMinimumNeighbourhoodFeature,
     EQtlColocClppMaximumFeature,
+    EQtlColocClppMaximumNeighbourhoodFeature,
     EQtlColocH4MaximumFeature,
+    EQtlColocH4MaximumNeighbourhoodFeature,
     L2GFeatureInputLoader,
     PQtlColocClppMaximumFeature,
+    PQtlColocClppMaximumNeighbourhoodFeature,
     PQtlColocH4MaximumFeature,
+    PQtlColocH4MaximumNeighbourhoodFeature,
     SQtlColocClppMaximumFeature,
+    SQtlColocClppMaximumNeighbourhoodFeature,
     SQtlColocH4MaximumFeature,
+    SQtlColocH4MaximumNeighbourhoodFeature,
     TuQtlColocClppMaximumFeature,
+    TuQtlColocClppMaximumNeighbourhoodFeature,
     TuQtlColocH4MaximumFeature,
+    TuQtlColocH4MaximumNeighbourhoodFeature,
 )
 
 if TYPE_CHECKING:
     from gentropy.dataset.colocalisation import Colocalisation
     from gentropy.dataset.study_index import StudyIndex
     from gentropy.dataset.study_locus import StudyLocus
+    from gentropy.dataset.variant_index import VariantIndex
 
 
 @pytest.mark.parametrize(
@@ -45,6 +53,14 @@
         PQtlColocClppMaximumFeature,
         SQtlColocClppMaximumFeature,
         TuQtlColocClppMaximumFeature,
+        EQtlColocClppMaximumNeighbourhoodFeature,
+        PQtlColocClppMaximumNeighbourhoodFeature,
+        SQtlColocClppMaximumNeighbourhoodFeature,
+        
TuQtlColocClppMaximumNeighbourhoodFeature, + EQtlColocH4MaximumNeighbourhoodFeature, + PQtlColocH4MaximumNeighbourhoodFeature, + SQtlColocH4MaximumNeighbourhoodFeature, + TuQtlColocH4MaximumNeighbourhoodFeature, DistanceTssMeanFeature, DistanceTssMinimumFeature, DistanceFootprintMeanFeature, From f4f8ae0d957f027d338d0b38836a9fcf8fe5d420 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 18 Sep 2024 16:50:19 +0100 Subject: [PATCH 33/48] chore: final v2g deletion --- config/step/ot_locus_to_gene_predict.yaml | 2 +- config/step/ot_locus_to_gene_train.yaml | 2 +- docs/howto/command_line/run_step_in_cli.md | 1 - notebooks/Release_QC_metrics.ipynb | 85 +------------------ src/gentropy/config.py | 4 +- src/gentropy/dataset/intervals.py | 34 +------- src/gentropy/dataset/l2g_gold_standard.py | 10 ++- .../open_targets/l2g_gold_standard.py | 28 +++--- src/gentropy/l2g.py | 27 +++--- tests/gentropy/conftest.py | 26 ------ tests/gentropy/dataset/test_intervals.py | 18 ---- .../open_targets/test_l2g_gold_standard.py | 62 ++++++++++---- tests/gentropy/test_schemas.py | 30 ++++--- 13 files changed, 115 insertions(+), 214 deletions(-) delete mode 100644 tests/gentropy/dataset/test_intervals.py diff --git a/config/step/ot_locus_to_gene_predict.yaml b/config/step/ot_locus_to_gene_predict.yaml index 97080223a..dffb36a61 100644 --- a/config/step/ot_locus_to_gene_predict.yaml +++ b/config/step/ot_locus_to_gene_predict.yaml @@ -5,6 +5,6 @@ run_mode: predict model_path: null predictions_path: ${datasets.l2g_predictions} credible_set_path: ${datasets.credible_set} -variant_gene_path: ${datasets.variant_to_gene} +variant_index_path: ${datasets.variant_index} colocalisation_path: ${datasets.colocalisation} study_index_path: ${datasets.study_index} diff --git a/config/step/ot_locus_to_gene_train.yaml b/config/step/ot_locus_to_gene_train.yaml index 181e1303b..5b2b43d90 100644 --- a/config/step/ot_locus_to_gene_train.yaml +++ b/config/step/ot_locus_to_gene_train.yaml @@ -7,7 +7,7 @@ hf_hub_repo_id: opentargets/locus_to_gene model_path: ${datasets.l2g_model} predictions_path: ${datasets.l2g_predictions} credible_set_path: ${datasets.credible_set} -variant_gene_path: ${datasets.variant_to_gene} +variant_index_path: ${datasets.variant_index} colocalisation_path: ${datasets.colocalisation} study_index_path: ${datasets.study_index} gold_standard_curation_path: ${datasets.l2g_gold_standard_curation} diff --git a/docs/howto/command_line/run_step_in_cli.md b/docs/howto/command_line/run_step_in_cli.md index ac7d55ff9..b935d84fb 100644 --- a/docs/howto/command_line/run_step_in_cli.md +++ b/docs/howto/command_line/run_step_in_cli.md @@ -24,7 +24,6 @@ Available options: ukbiobank variant_annotation variant_index - variant_to_gene Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace. ``` diff --git a/notebooks/Release_QC_metrics.ipynb b/notebooks/Release_QC_metrics.ipynb index aa0924711..4eb27015b 100644 --- a/notebooks/Release_QC_metrics.ipynb +++ b/notebooks/Release_QC_metrics.ipynb @@ -13,21 +13,17 @@ "1. Import necessary modules and set up the release path and version.\n", "2. Load and analyze the variant index data:\n", " - Count the number of unique variants.\n", - "3. Load and analyze the variant-to-gene (v2g) data:\n", - " - Count the number of unique variants and total variant-to-gene assignments.\n", - " - Count the number of v2g assignments where the score is > 0.8.\n", - " - Plot a histogram/density plot for the \"score\" column.\n", - "4. 
Load and analyze the study index data for different data sources (FinnGen, GWASCat, eQTLcat):\n", + "3. Load and analyze the study index data for different data sources (FinnGen, GWASCat, eQTLcat):\n", " - Count the number of unique studies for each data source.\n", - "5. Analyze the credible sets for each datasource (Finngen, gwascat, eqtlcat):\n", + "4. Analyze the credible sets for each datasource (Finngen, gwascat, eqtlcat):\n", " - Analyze the credible sets:\n", " - Count the number of unique credible sets and unique study IDs.\n", " - Plot a scatter plot of the credible set size vs. the top posterior probability.\n", " - Count the number of credible sets with a top SNP posterior probability > 0.9..\n", - "6. Analyze colocalization data:\n", + "5. Analyze colocalization data:\n", " - Count the total number of colocalizations and the number with clpp > 0.8.\n", " - Calculate the average number of overlaps per credible set.\n", - "7. Analyze locus-to-gene (L2G) predictions:\n", + "6. Analyze locus-to-gene (L2G) predictions:\n", " - Load the locus-to-gene predictions data.\n", " - How many Studylocus contains a \"good\" l2g prediction? (l2g_score > 0.5)\n", " - How does l2g perform based on different datasource inputs? (impossible to tell)\n", @@ -126,79 +122,6 @@ "#variant_index.filter(variant_index[\"alleleFrequencies.populationName\"] > 0.05).show(10, False)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "#### 3. Load and analyze the variant-to-gene (v2g) data:\n", - " - Count the number of unique variants and total variant-to-gene assignments.\n", - " - Count the number of v2g assignments where the score is > 0.8." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Unique variants in v2g release: 5090991 , total variant to gene assignments: 105771851 , number of v2g assignments where score > 0.8: 23176515 ( 4.552 %)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Summary of v2g_score: Mean: 0.5909395615801637 L.quart: 0.29 Median: 0.62 U.quart: 0.94\n" - ] - } - ], - "source": [ - "#v2g_path='gs://genetics_etl_python_playground/releases/24.03/variant_to_gene'\n", - "v2g_path=f\"{release_path}/{release_ver}/variant_to_gene\"\n", - "v2g=session.spark.read.parquet(v2g_path, recursiveFileLookup=True)\n", - "\n", - "#How many variants?\n", - "sample_size_quartiles = v2g.stat.approxQuantile(\"score\", [0.25, 0.5, 0.75], 0.01)\n", - "#v2g.select().toPandas().plot.hist()\n", - "#v2g.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - " - Plot a histogram/density plot for the \"score\" column." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "#The histogram/density plot for “score”\n", - "# Out of mem error:\n", - "#v2g.select(f.col(\"score\")).toPandas().plot.hist(bins=10, alpha=0.5, label=\"v2g scores\")" - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/gentropy/config.py b/src/gentropy/config.py index e6cb49ab8..fca9b4217 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass, field -from typing import Any, Dict, List +from typing import Any, List from hail import __file__ as hail_location from hydra.core.config_store import ConfigStore @@ -218,7 +218,7 @@ class LocusToGeneConfig(StepConfig): run_mode: str = MISSING predictions_path: str = MISSING credible_set_path: str = MISSING - variant_gene_path: str = MISSING + variant_index_path: str = MISSING colocalisation_path: str = MISSING study_index_path: str = MISSING model_path: str | None = None diff --git a/src/gentropy/dataset/intervals.py b/src/gentropy/dataset/intervals.py index c3b9136c9..37158810b 100644 --- a/src/gentropy/dataset/intervals.py +++ b/src/gentropy/dataset/intervals.py @@ -1,22 +1,19 @@ """Interval dataset.""" + from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING -import pyspark.sql.functions as f - from gentropy.common.Liftover import LiftOverSpark from gentropy.common.schemas import parse_spark_schema from gentropy.dataset.dataset import Dataset from gentropy.dataset.gene_index import GeneIndex -from gentropy.dataset.v2g import V2G if TYPE_CHECKING: from pyspark.sql import SparkSession from pyspark.sql.types import StructType - from gentropy.dataset.variant_index import VariantIndex @dataclass @@ -74,32 +71,3 @@ def from_source( source_class = source_to_class[source_name] data = source_class.read(spark, source_path) # type: ignore return source_class.parse(data, gene_index, lift) # type: ignore - - def v2g(self: Intervals, variant_index: VariantIndex) -> V2G: - """Convert intervals into V2G by intersecting with a variant index. - - Args: - variant_index (VariantIndex): Variant index dataset - - Returns: - V2G: Variant-to-gene evidence dataset - """ - return V2G( - _df=( - self.df.alias("interval") - .join( - variant_index.df.selectExpr( - "chromosome as vi_chromosome", "variantId", "position" - ).alias("vi"), - on=[ - f.col("vi.vi_chromosome") == f.col("interval.chromosome"), - f.col("vi.position").between( - f.col("interval.start"), f.col("interval.end") - ), - ], - how="inner", - ) - .drop("start", "end", "vi_chromosome", "position") - ), - _schema=V2G.get_schema(), - ) diff --git a/src/gentropy/dataset/l2g_gold_standard.py b/src/gentropy/dataset/l2g_gold_standard.py index 89f4c5f5d..064f6cc0e 100644 --- a/src/gentropy/dataset/l2g_gold_standard.py +++ b/src/gentropy/dataset/l2g_gold_standard.py @@ -18,7 +18,7 @@ from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.study_locus_overlap import StudyLocusOverlap - from gentropy.dataset.v2g import V2G + from gentropy.dataset.variant_index import VariantIndex @dataclass @@ -33,16 +33,16 @@ class L2GGoldStandard(Dataset): def from_otg_curation( cls: type[L2GGoldStandard], gold_standard_curation: DataFrame, - v2g: V2G, study_locus_overlap: StudyLocusOverlap, + variant_index: VariantIndex, interactions: DataFrame, ) -> L2GGoldStandard: """Initialise L2GGoldStandard from source dataset. 
         Args:
             gold_standard_curation (DataFrame): Gold standard curation dataframe, extracted from
-            v2g (V2G): Variant to gene dataset to bring distance between a variant and a gene's TSS
             study_locus_overlap (StudyLocusOverlap): Study locus overlap dataset to remove duplicated loci
+            variant_index (VariantIndex): Dataset that provides the distance between a variant and a gene's TSS
             interactions (DataFrame): Gene-gene interactions dataset to remove negative cases where the gene interacts with a positive gene
 
         Returns:
@@ -55,7 +55,9 @@ def from_otg_curation(
         interactions_df = cls.process_gene_interactions(interactions)
 
         return (
-            OpenTargetsL2GGoldStandard.as_l2g_gold_standard(gold_standard_curation, v2g)
+            OpenTargetsL2GGoldStandard.as_l2g_gold_standard(
+                gold_standard_curation, variant_index
+            )
             # .filter_unique_associations(study_locus_overlap)
             .remove_false_negatives(interactions_df)
         )
diff --git a/src/gentropy/datasource/open_targets/l2g_gold_standard.py b/src/gentropy/datasource/open_targets/l2g_gold_standard.py
index 2cfcd62f8..0deb2e2a7 100644
--- a/src/gentropy/datasource/open_targets/l2g_gold_standard.py
+++ b/src/gentropy/datasource/open_targets/l2g_gold_standard.py
@@ -9,7 +9,7 @@
 
 from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
 from gentropy.dataset.study_locus import StudyLocus
-from gentropy.dataset.v2g import V2G
+from gentropy.dataset.variant_index import VariantIndex
 
 
 class OpenTargetsL2GGoldStandard:
@@ -60,7 +60,9 @@ def parse_positive_curation(
 
     @classmethod
     def expand_gold_standard_with_negatives(
-        cls: Type[OpenTargetsL2GGoldStandard], positive_set: DataFrame, v2g: V2G
+        cls: Type[OpenTargetsL2GGoldStandard],
+        positive_set: DataFrame,
+        variant_index: VariantIndex,
     ) -> DataFrame:
         """Create full set of positive and negative evidence of locus to gene associations.
@@ -68,7 +70,7 @@ def expand_gold_standard_with_negatives(
 
         Args:
             positive_set (DataFrame): Positive set from curation
-            v2g (V2G): Variant to gene dataset to bring distance between a variant and a gene's TSS
+            variant_index (VariantIndex): Variant index to get distance to gene
 
         Returns:
             DataFrame: Full set of positive and negative evidence of locus to gene associations
@@ -76,9 +78,13 @@ def expand_gold_standard_with_negatives(
         return (
             positive_set.withColumnRenamed("geneId", "curated_geneId")
             .join(
-                v2g.df.selectExpr(
-                    "variantId", "geneId as non_curated_geneId", "distance"
-                ).filter(f.col("distance") <= cls.LOCUS_TO_GENE_WINDOW),
+                variant_index.get_distance_to_gene()
+                .selectExpr(
+                    "variantId",
+                    "targetId as non_curated_geneId",
+                    "distanceFromTss",
+                )
+                .filter(f.col("distanceFromTss") <= cls.LOCUS_TO_GENE_WINDOW),
                 on="variantId",
                 how="left",
             )
@@ -86,7 +92,7 @@ def expand_gold_standard_with_negatives(
             .withColumn(
                 "goldStandardSet",
                 f.when(
                     (f.col("curated_geneId") == f.col("non_curated_geneId"))
-                    # to keep the positives that are outside the v2g dataset
+                    # to keep the positives that are not part of the variant index
                     | (f.col("non_curated_geneId").isNull()),
                     f.lit(L2GGoldStandard.GS_POSITIVE_LABEL),
                 ).otherwise(L2GGoldStandard.GS_NEGATIVE_LABEL),
@@ -98,27 +104,27 @@ def expand_gold_standard_with_negatives(
                 f.col("curated_geneId"),
             ).otherwise(f.col("non_curated_geneId")),
         )
-        .drop("distance", "curated_geneId", "non_curated_geneId")
+        .drop("distanceFromTss", "curated_geneId", "non_curated_geneId")
     )
 
     @classmethod
     def as_l2g_gold_standard(
         cls: type[OpenTargetsL2GGoldStandard],
         gold_standard_curation: DataFrame,
-        v2g: V2G,
+        variant_index: VariantIndex,
     ) -> L2GGoldStandard:
         """Initialise L2GGoldStandard from source dataset.
 
         Args:
             gold_standard_curation (DataFrame): Gold standard curation dataframe, extracted from https://github.com/opentargets/genetics-gold-standards
-            v2g (V2G): Variant to gene dataset to bring distance between a variant and a gene's TSS
+            variant_index (VariantIndex): Dataset that provides the distance between a variant and a gene's TSS
 
         Returns:
             L2GGoldStandard: L2G Gold Standard dataset. False negatives have not yet been removed.
         """
         return L2GGoldStandard(
             _df=cls.parse_positive_curation(gold_standard_curation).transform(
                 cls.expand_gold_standard_with_negatives, variant_index
             ),
             _schema=L2GGoldStandard.get_schema(),
         )
diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py
index 7ce77eaa1..7722ba894 100644
--- a/src/gentropy/l2g.py
+++ b/src/gentropy/l2g.py
@@ -17,7 +17,7 @@
 from gentropy.dataset.l2g_prediction import L2GPrediction
 from gentropy.dataset.study_index import StudyIndex
 from gentropy.dataset.study_locus import StudyLocus
-from gentropy.dataset.v2g import V2G
+from gentropy.dataset.variant_index import VariantIndex
 from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader
 from gentropy.method.l2g.model import LocusToGeneModel
 from gentropy.method.l2g.trainer import LocusToGeneTrainer
@@ -38,7 +38,7 @@ def __init__(
         model_path: str | None = None,
         credible_set_path: str,
         gold_standard_curation_path: str | None = None,
-        variant_gene_path: str | None = None,
+        variant_index_path: str | None = None,
         colocalisation_path: str | None = None,
         study_index_path: str | None = None,
         gene_interactions_path: str | None = None,
@@ -59,7 +59,7 @@ def __init__(
             model_path (str | None): Path to the model. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name).
             credible_set_path (str): Path to the credible set dataset necessary to build the feature matrix
             gold_standard_curation_path (str | None): Path to the gold standard curation file
-            variant_gene_path (str | None): Path to the variant-gene dataset
+            variant_index_path (str | None): Path to the variant index dataset
             colocalisation_path (str | None): Path to the colocalisation dataset
             study_index_path (str | None): Path to the study index dataset
             gene_interactions_path (str | None): Path to the gene interactions dataset
@@ -96,8 +96,10 @@ def __init__(
             if study_index_path
             else None
         )
-        self.v2g = (
-            V2G.from_parquet(session, variant_gene_path) if variant_gene_path else None
+        self.variant_index = (
+            VariantIndex.from_parquet(session, variant_index_path)
+            if variant_index_path
+            else None
         )
         self.coloc = (
             Colocalisation.from_parquet(
@@ -107,7 +109,7 @@ def __init__(
             else None
         )
         self.features_input_loader = L2GFeatureInputLoader(
-            v2g=self.v2g,
+            variant_index=self.variant_index,
             coloc=self.coloc,
             studies=self.studies,
         )
@@ -133,7 +135,7 @@ def run_predict(self) -> None:
         Raises:
             ValueError: If not all dependencies in prediction mode are set
         """
-        if self.studies and self.v2g and self.coloc:
+        if self.studies and self.coloc and self.variant_index:
             predictions = L2GPrediction.from_credible_set(
                 self.session,
                 self.credible_set,
@@ -156,9 +158,9 @@ def run_train(self) -> None:
         if (
             self.gs_curation
             and self.interactions
-            and self.v2g
             and self.wandb_run_name
             and self.model_path
+            and self.variant_index
         ):
             wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev")
             # Process gold standard and L2G features
@@ -203,7 +205,12 @@ def _generate_feature_matrix(self, write_feature_matrix: bool) -> L2GFeatureMatr
             ValueError: If write_feature_matrix is set to True but a path is not provided.
             ValueError: If dependencies to build features are not set.
""" - if self.gs_curation and self.interactions and self.v2g and self.studies: + if ( + self.gs_curation + and self.interactions + and self.studies + and self.variant_index + ): study_locus_overlap = StudyLocus( _df=self.credible_set.df.join( f.broadcast( @@ -229,7 +236,7 @@ def _generate_feature_matrix(self, write_feature_matrix: bool) -> L2GFeatureMatr gold_standards = L2GGoldStandard.from_otg_curation( gold_standard_curation=self.gs_curation, - v2g=self.v2g, + variant_index=self.variant_index, study_locus_overlap=study_locus_overlap, interactions=self.interactions, ) diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py index 903a9d5b5..13b3f510f 100644 --- a/tests/gentropy/conftest.py +++ b/tests/gentropy/conftest.py @@ -24,7 +24,6 @@ from gentropy.dataset.study_locus import StudyLocus from gentropy.dataset.study_locus_overlap import StudyLocusOverlap from gentropy.dataset.summary_statistics import SummaryStatistics -from gentropy.dataset.v2g import V2G from gentropy.dataset.variant_index import VariantIndex from gentropy.datasource.eqtl_catalogue.finemapping import EqtlCatalogueFinemapping from gentropy.datasource.eqtl_catalogue.study_index import EqtlCatalogueStudyIndex @@ -250,31 +249,6 @@ def mock_intervals(spark: SparkSession) -> Intervals: return Intervals(_df=data_spec.build(), _schema=interval_schema) -@pytest.fixture() -def mock_v2g(spark: SparkSession) -> V2G: - """Mock v2g dataset.""" - v2g_schema = V2G.get_schema() - - data_spec = ( - dg.DataGenerator( - spark, - rows=400, - partitions=4, - randomSeedMethod="hash_fieldname", - ) - .withSchema(v2g_schema) - .withColumnSpec("distance", percentNulls=0.1) - .withColumnSpec("resourceScore", percentNulls=0.1) - .withColumnSpec("score", percentNulls=0.1) - .withColumnSpec("pmid", percentNulls=0.1) - .withColumnSpec("biofeature", percentNulls=0.1) - .withColumnSpec("variantFunctionalConsequenceId", percentNulls=0.1) - .withColumnSpec("isHighQualityPlof", percentNulls=0.1) - ) - - return V2G(_df=data_spec.build(), _schema=v2g_schema) - - @pytest.fixture() def mock_variant_consequence_to_score(spark: SparkSession) -> DataFrame: """Slice of the VEP consequence to score table.""" diff --git a/tests/gentropy/dataset/test_intervals.py b/tests/gentropy/dataset/test_intervals.py deleted file mode 100644 index 26d79acd1..000000000 --- a/tests/gentropy/dataset/test_intervals.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Tests on LD index.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from gentropy.dataset.v2g import V2G - -if TYPE_CHECKING: - from gentropy.dataset.intervals import Intervals - from gentropy.dataset.variant_index import VariantIndex - - -def test_interval_v2g_creation( - mock_intervals: Intervals, mock_variant_index: VariantIndex -) -> None: - """Test creation of V2G from intervals.""" - assert isinstance(mock_intervals.v2g(mock_variant_index), V2G) diff --git a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py index 347b7ec69..68a0bf9c6 100644 --- a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py +++ b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py @@ -6,11 +6,19 @@ import pytest from pyspark.sql import DataFrame +from pyspark.sql.types import ( + ArrayType, + IntegerType, + LongType, + StringType, + StructField, + StructType, +) from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from 
gentropy.dataset.study_index import StudyIndex -from gentropy.dataset.v2g import V2G +from gentropy.dataset.variant_index import VariantIndex from gentropy.datasource.open_targets.l2g_gold_standard import ( OpenTargetsL2GGoldStandard, ) @@ -25,13 +33,13 @@ def test_open_targets_as_l2g_gold_standard( sample_l2g_gold_standard: DataFrame, - mock_v2g: V2G, + mock_variant_index: VariantIndex, ) -> None: """Test L2G gold standard from OTG curation.""" assert isinstance( OpenTargetsL2GGoldStandard.as_l2g_gold_standard( sample_l2g_gold_standard, - mock_v2g, + mock_variant_index, ), L2GGoldStandard, ) @@ -81,19 +89,41 @@ def _setup(self: TestExpandGoldStandardWithNegatives, spark: SparkSession) -> No ["variantId", "geneId", "studyId"], ) - sample_v2g_df = spark.createDataFrame( - [ - ("variant1", "gene1", 5, "X", "X", "X"), - ("variant1", "gene3", 10, "X", "X", "X"), - ], + sample_variant_index_df = spark.createDataFrame( [ - "variantId", - "geneId", - "distance", - "chromosome", - "datatypeId", - "datasourceId", + ( + "variant1", + "chrom", + 1, + "A", + "T", + [ + {"distanceFromTss": 5, "targetId": "gene1"}, + {"distanceFromTss": 10, "targetId": "gene3"}, + ], + ), ], + StructType( + [ + StructField("variantId", StringType(), True), + StructField("chromosome", StringType(), True), + StructField("position", IntegerType(), True), + StructField("referenceAllele", StringType(), True), + StructField("alternateAllele", StringType(), True), + StructField( + "transcriptConsequences", + ArrayType( + StructType( + [ + StructField("distanceFromTss", LongType(), True), + StructField("targetId", StringType(), True), + ] + ) + ), + True, + ), + ] + ), ) self.expected_expanded_gs = spark.createDataFrame( @@ -107,7 +137,9 @@ def _setup(self: TestExpandGoldStandardWithNegatives, spark: SparkSession) -> No self.observed_df = ( OpenTargetsL2GGoldStandard.expand_gold_standard_with_negatives( self.sample_positive_set, - V2G(_df=sample_v2g_df, _schema=V2G.get_schema()), + VariantIndex( + _df=sample_variant_index_df, _schema=VariantIndex.get_schema() + ), ) ) diff --git a/tests/gentropy/test_schemas.py b/tests/gentropy/test_schemas.py index 1af72c149..4cafa4466 100644 --- a/tests/gentropy/test_schemas.py +++ b/tests/gentropy/test_schemas.py @@ -16,7 +16,7 @@ from _pytest.fixtures import FixtureRequest from gentropy.dataset.gene_index import GeneIndex - from gentropy.dataset.v2g import V2G + from gentropy.dataset.l2g_prediction import L2GPrediction SCHEMA_DIR = "src/gentropy/assets/schemas" @@ -73,21 +73,23 @@ def test_schema_columns_camelcase(schema_json: str) -> None: class TestValidateSchema: - """Test validate_schema method using V2G (unnested) and GeneIndex (nested) as a testing dataset.""" + """Test validate_schema method using L2GPrediction (unnested) and GeneIndex (nested) as a testing dataset.""" @pytest.fixture() def mock_dataset_instance( self: TestValidateSchema, request: FixtureRequest - ) -> V2G | GeneIndex: + ) -> L2GPrediction | GeneIndex: """Meta fixture to return the value of any requested fixture.""" return request.getfixturevalue(request.param) @pytest.mark.parametrize( - "mock_dataset_instance", ["mock_v2g", "mock_gene_index"], indirect=True + "mock_dataset_instance", + ["mock_l2g_predictions", "mock_gene_index"], + indirect=True, ) def test_validate_schema_extra_field( self: TestValidateSchema, - mock_dataset_instance: V2G | GeneIndex, + mock_dataset_instance: L2GPrediction | GeneIndex, ) -> None: """Test that validate_schema raises an error if the observed schema has an extra field.""" with 
pytest.raises(ValueError, match="extraField"): @@ -96,22 +98,26 @@ def test_validate_schema_extra_field( ) @pytest.mark.parametrize( - "mock_dataset_instance", ["mock_v2g", "mock_gene_index"], indirect=True + "mock_dataset_instance", + ["mock_l2g_predictions", "mock_gene_index"], + indirect=True, ) def test_validate_schema_missing_field( self: TestValidateSchema, - mock_dataset_instance: V2G | GeneIndex, + mock_dataset_instance: L2GPrediction | GeneIndex, ) -> None: """Test that validate_schema raises an error if the observed schema is missing a required field, geneId in this case.""" with pytest.raises(ValueError, match="geneId"): mock_dataset_instance.df = mock_dataset_instance.df.drop("geneId") @pytest.mark.parametrize( - "mock_dataset_instance", ["mock_v2g", "mock_gene_index"], indirect=True + "mock_dataset_instance", + ["mock_l2g_predictions", "mock_gene_index"], + indirect=True, ) def test_validate_schema_duplicated_field( self: TestValidateSchema, - mock_dataset_instance: V2G | GeneIndex, + mock_dataset_instance: L2GPrediction | GeneIndex, ) -> None: """Test that validate_schema raises an error if the observed schema has a duplicated field, geneId in this case.""" with pytest.raises(ValueError, match="geneId"): @@ -120,11 +126,13 @@ def test_validate_schema_duplicated_field( ) @pytest.mark.parametrize( - "mock_dataset_instance", ["mock_v2g", "mock_gene_index"], indirect=True + "mock_dataset_instance", + ["mock_l2g_predictions", "mock_gene_index"], + indirect=True, ) def test_validate_schema_different_datatype( self: TestValidateSchema, - mock_dataset_instance: V2G | GeneIndex, + mock_dataset_instance: L2GPrediction | GeneIndex, ) -> None: """Test that validate_schema raises an error if any field in the observed schema has a different type than expected.""" with pytest.raises(ValueError, match="geneId"): From 95793c60fa097b6cd1d16f78862c296cf15d018e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 18 Sep 2024 17:12:47 +0100 Subject: [PATCH 34/48] chore: drop config yamls --- config/step/ot_locus_to_gene_predict.yaml | 10 ------- config/step/ot_locus_to_gene_train.yaml | 36 ----------------------- src/gentropy/config.py | 15 +++++++++- 3 files changed, 14 insertions(+), 47 deletions(-) delete mode 100644 config/step/ot_locus_to_gene_predict.yaml delete mode 100644 config/step/ot_locus_to_gene_train.yaml diff --git a/config/step/ot_locus_to_gene_predict.yaml b/config/step/ot_locus_to_gene_predict.yaml deleted file mode 100644 index 97080223a..000000000 --- a/config/step/ot_locus_to_gene_predict.yaml +++ /dev/null @@ -1,10 +0,0 @@ -defaults: - - locus_to_gene - -run_mode: predict -model_path: null -predictions_path: ${datasets.l2g_predictions} -credible_set_path: ${datasets.credible_set} -variant_gene_path: ${datasets.variant_to_gene} -colocalisation_path: ${datasets.colocalisation} -study_index_path: ${datasets.study_index} diff --git a/config/step/ot_locus_to_gene_train.yaml b/config/step/ot_locus_to_gene_train.yaml deleted file mode 100644 index 181e1303b..000000000 --- a/config/step/ot_locus_to_gene_train.yaml +++ /dev/null @@ -1,36 +0,0 @@ -defaults: - - locus_to_gene - -run_mode: train -wandb_run_name: null -hf_hub_repo_id: opentargets/locus_to_gene -model_path: ${datasets.l2g_model} -predictions_path: ${datasets.l2g_predictions} -credible_set_path: ${datasets.credible_set} -variant_gene_path: ${datasets.variant_to_gene} -colocalisation_path: ${datasets.colocalisation} -study_index_path: ${datasets.study_index} -gold_standard_curation_path: 
${datasets.l2g_gold_standard_curation} -gene_interactions_path: ${datasets.gene_interactions} -feature_matrix_path: ${datasets.l2g_feature_matrix} -write_feature_matrix: true -hyperparameters: - n_estimators: 100 - max_depth: 5 - loss: log_loss -download_from_hub: true -features_list: - # average distance of all tagging variants to gene TSS - # - distanceTssMean - # minimum distance of all tagging variants to gene TSS - # - distanceTssMinimum - # max CLPP for each (study, locus, gene) aggregating over a specific qtl type - - eQtlColocClppMaximum - - pQtlColocClppMaximum - - sQtlColocClppMaximum - - tuQtlColocClppMaximum - # max H4 for each (study, locus, gene) aggregating over a specific qtl type - - eQtlColocH4Maximum - - pQtlColocH4Maximum - - sQtlColocH4Maximum - - tuQtlColocH4Maximum diff --git a/src/gentropy/config.py b/src/gentropy/config.py index aa8e331af..d5e02924b 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -225,7 +225,20 @@ class LocusToGeneConfig(StepConfig): feature_matrix_path: str | None = None gold_standard_curation_path: str | None = None gene_interactions_path: str | None = None - features_list: list[dict[str, str]] = MISSING + features_list: list[str] = field( + default_factory=lambda: [ + # max CLPP for each (study, locus, gene) aggregating over a specific qtl type + "eQtlColocClppMaximum", + "pQtlColocClppMaximum", + "sQtlColocClppMaximum", + "tuQtlColocClppMaximum", + # max H4 for each (study, locus, gene) aggregating over a specific qtl type + "eQtlColocH4Maximum", + "pQtlColocH4Maximum", + "sQtlColocH4Maximum", + "tuQtlColocH4Maximum", + ] + ) hyperparameters: dict[str, Any] = field( default_factory=lambda: { "n_estimators": 100, From cb5c16920b1014980ed45e291e87453982be86e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 18 Sep 2024 17:34:05 +0100 Subject: [PATCH 35/48] refactor: move feature classes to datasets module --- src/gentropy/dataset/l2g_feature.py | 432 ++++++++++++++++- src/gentropy/method/l2g/feature_factory.py | 450 +----------------- tests/gentropy/dataset/test_l2g_feature.py | 57 +++ .../method/test_l2g/test_feature_factory.py | 193 -------- 4 files changed, 502 insertions(+), 630 deletions(-) create mode 100644 tests/gentropy/dataset/test_l2g_feature.py delete mode 100644 tests/gentropy/method/test_l2g/test_feature_factory.py diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py index 765bb5d55..8a1baddce 100644 --- a/src/gentropy/dataset/l2g_feature.py +++ b/src/gentropy/dataset/l2g_feature.py @@ -1,4 +1,4 @@ -"""L2G Feature Dataset.""" +"""L2G Feature Dataset with a collection of methods that extract features from the gentropy datasets to be fed in L2G.""" from __future__ import annotations @@ -6,10 +6,17 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any +import pyspark.sql.functions as f + from gentropy.common.schemas import parse_spark_schema +from gentropy.common.spark_helpers import convert_from_wide_to_long +from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.dataset import Dataset +from gentropy.dataset.study_index import StudyIndex +from gentropy.dataset.v2g import V2G if TYPE_CHECKING: + from pyspark.sql import DataFrame from pyspark.sql.types import StructType from gentropy.dataset.l2g_gold_standard import L2GGoldStandard @@ -60,3 +67,426 @@ def compute( L2GFeature: a L2GFeature dataset """ pass + + +def _common_colocalisation_feature_logic( + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + 
colocalisation_method: str, + colocalisation_metric: str, + feature_name: str, + qtl_type: str, + *, + colocalisation: Colocalisation, + study_index: StudyIndex, +) -> DataFrame: + """Wrapper to call the logic that creates a type of colocalisation features. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + colocalisation_method (str): The colocalisation method to filter the data by + colocalisation_metric (str): The colocalisation metric to use + feature_name (str): The name of the feature to create + qtl_type (str): The type of QTL to filter the data by + colocalisation (Colocalisation): Dataset with the colocalisation results + study_index (StudyIndex): Study index to fetch study type and gene + + Returns: + DataFrame: Feature annotation in long format with the columns: studyLocusId, geneId, featureName, featureValue + """ + return convert_from_wide_to_long( + colocalisation.extract_maximum_coloc_probability_per_region_and_gene( + study_loci_to_annotate, + study_index, + filter_by_colocalisation_method=colocalisation_method, + filter_by_qtl=qtl_type, + ).selectExpr( + "studyLocusId", + "geneId", + f"{colocalisation_metric} as {feature_name}", + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ) + + +class EQtlColocClppMaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all eQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex] + feature_name = "eQtlColocClppMaximum" + + @classmethod + def compute( + cls: type[EQtlColocClppMaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> EQtlColocClppMaximumFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments. + + Returns: + EQtlColocClppMaximumFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "eqtl" + + return cls( + _df=_common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + _schema=cls.get_schema(), + ) + + +class PQtlColocClppMaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex] + feature_name = "pQtlColocClppMaximum" + + @classmethod + def compute( + cls: type[PQtlColocClppMaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> PQtlColocClppMaximumFeature: + """Computes the feature. 
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset with the colocalisation results
+
+        Returns:
+            PQtlColocClppMaximumFeature: Feature dataset
+        """
+        colocalisation_method = "ECaviar"
+        colocalisation_metric = "clpp"
+        qtl_type = "pqtl"
+        return cls(
+            _df=_common_colocalisation_feature_logic(
+                study_loci_to_annotate,
+                colocalisation_method,
+                colocalisation_metric,
+                cls.feature_name,
+                qtl_type,
+                **feature_dependency,
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class SQtlColocClppMaximumFeature(L2GFeature):
+    """Max CLPP for each (study, locus, gene) aggregating over all sQTLs."""
+
+    feature_dependency_type = [Colocalisation, StudyIndex]
+    feature_name = "sQtlColocClppMaximum"
+
+    @classmethod
+    def compute(
+        cls: type[SQtlColocClppMaximumFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> SQtlColocClppMaximumFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset with the colocalisation results
+
+        Returns:
+            SQtlColocClppMaximumFeature: Feature dataset
+        """
+        colocalisation_method = "ECaviar"
+        colocalisation_metric = "clpp"
+        qtl_type = "sqtl"
+        return cls(
+            _df=_common_colocalisation_feature_logic(
+                study_loci_to_annotate,
+                colocalisation_method,
+                colocalisation_metric,
+                cls.feature_name,
+                qtl_type,
+                **feature_dependency,
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class TuQtlColocClppMaximumFeature(L2GFeature):
+    """Max CLPP for each (study, locus, gene) aggregating over all tuQTLs."""
+
+    feature_dependency_type = [Colocalisation, StudyIndex]
+    feature_name = "tuQtlColocClppMaximum"
+
+    @classmethod
+    def compute(
+        cls: type[TuQtlColocClppMaximumFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> TuQtlColocClppMaximumFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset with the colocalisation results
+
+        Returns:
+            TuQtlColocClppMaximumFeature: Feature dataset
+        """
+        colocalisation_method = "ECaviar"
+        colocalisation_metric = "clpp"
+        qtl_type = "tuqtl"
+        return cls(
+            _df=_common_colocalisation_feature_logic(
+                study_loci_to_annotate,
+                colocalisation_method,
+                colocalisation_metric,
+                cls.feature_name,
+                qtl_type,
+                **feature_dependency,
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class EQtlColocH4MaximumFeature(L2GFeature):
+    """Max H4 for each (study, locus, gene) aggregating over all eQTLs."""
+
+    feature_dependency_type = [Colocalisation, StudyIndex]
+    feature_name = "eQtlColocH4Maximum"
+
+    @classmethod
+    def compute(
+        cls: type[EQtlColocH4MaximumFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> EQtlColocH4MaximumFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset with the colocalisation results
+
+        Returns:
+            EQtlColocH4MaximumFeature: Feature dataset
+        """
+        colocalisation_method = "Coloc"
+        colocalisation_metric = "h4"
+        qtl_type = "eqtl"
+        return cls(
+            _df=_common_colocalisation_feature_logic(
+                study_loci_to_annotate,
+                colocalisation_method,
+                colocalisation_metric,
+                cls.feature_name,
+                qtl_type,
+                **feature_dependency,
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class PQtlColocH4MaximumFeature(L2GFeature):
+    """Max H4 for each (study, locus, gene) aggregating over all pQTLs."""
+
+    feature_dependency_type = [Colocalisation, StudyIndex]
+    feature_name = "pQtlColocH4Maximum"
+
+    @classmethod
+    def compute(
+        cls: type[PQtlColocH4MaximumFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> PQtlColocH4MaximumFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset with the colocalisation results
+
+        Returns:
+            PQtlColocH4MaximumFeature: Feature dataset
+        """
+        colocalisation_method = "Coloc"
+        colocalisation_metric = "h4"
+        qtl_type = "pqtl"
+        return cls(
+            _df=_common_colocalisation_feature_logic(
+                study_loci_to_annotate,
+                colocalisation_method,
+                colocalisation_metric,
+                cls.feature_name,
+                qtl_type,
+                **feature_dependency,
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class SQtlColocH4MaximumFeature(L2GFeature):
+    """Max H4 for each (study, locus, gene) aggregating over all sQTLs."""
+
+    feature_dependency_type = [Colocalisation, StudyIndex]
+    feature_name = "sQtlColocH4Maximum"
+
+    @classmethod
+    def compute(
+        cls: type[SQtlColocH4MaximumFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> SQtlColocH4MaximumFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset with the colocalisation results
+
+        Returns:
+            SQtlColocH4MaximumFeature: Feature dataset
+        """
+        colocalisation_method = "Coloc"
+        colocalisation_metric = "h4"
+        qtl_type = "sqtl"
+        return cls(
+            _df=_common_colocalisation_feature_logic(
+                study_loci_to_annotate,
+                colocalisation_method,
+                colocalisation_metric,
+                cls.feature_name,
+                qtl_type,
+                **feature_dependency,
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class TuQtlColocH4MaximumFeature(L2GFeature):
+    """Max H4 for each (study, locus, gene) aggregating over all tuQTLs."""
+
+    feature_dependency_type = [Colocalisation, StudyIndex]
+    feature_name = "tuQtlColocH4Maximum"
+
+    @classmethod
+    def compute(
+        cls: type[TuQtlColocH4MaximumFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> TuQtlColocH4MaximumFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset with the colocalisation results
+
+        Returns:
+            TuQtlColocH4MaximumFeature: Feature dataset
+        """
+        colocalisation_method = "Coloc"
+        colocalisation_metric = "h4"
+        qtl_type = "tuqtl"
+        return cls(
+            _df=_common_colocalisation_feature_logic(
+                study_loci_to_annotate,
+                colocalisation_method,
+                colocalisation_metric,
+                cls.feature_name,
+                qtl_type,
+                **feature_dependency,
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class DistanceTssMinimumFeature(L2GFeature):
+    """Minimum distance of all tagging variants to gene TSS."""
+
+    @classmethod
+    def compute(
+        cls: type[DistanceTssMinimumFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: V2G,
+    ) -> L2GFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (V2G): Dataset that contains the distance information
+
+        Returns:
+            L2GFeature: Feature dataset
+
+        Raises:
+            NotImplementedError: Not implemented
+        """
+        raise NotImplementedError
+
+
+class DistanceTssMeanFeature(L2GFeature):
+    """Average distance of all tagging variants to gene TSS.
+
+    NOTE: to be rewritten taking variant index as input
+    """
+
+    fill_na_value = 500_000
+    feature_dependency_type = V2G
+
+    @classmethod
+    def compute(
+        cls: type[DistanceTssMeanFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: V2G,
+    ) -> DistanceTssMeanFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (V2G): Dataset that contains the distance information
+
+        Returns:
+            DistanceTssMeanFeature: Feature dataset
+        """
+        agg_expr = f.mean("weightedScore").alias("distanceTssMean")
+        # Everything but expression is common logic
+        v2g = feature_dependency.df.filter(f.col("datasourceId") == "canonical_tss")
+        wide_df = (
+            study_loci_to_annotate.df.withColumn(
+                "variantInLocus", f.explode_outer("locus")
+            )
+            .select(
+                "studyLocusId",
+                f.col("variantInLocus.variantId").alias("variantInLocusId"),
+                f.col("variantInLocus.posteriorProbability").alias(
+                    "variantInLocusPosteriorProbability"
+                ),
+            )
+            .join(
+                v2g.selectExpr("variantId as variantInLocusId", "geneId", "score"),
+                on="variantInLocusId",
+                how="inner",
+            )
+            .withColumn(
+                "weightedScore",
+                f.col("score") * f.col("variantInLocusPosteriorProbability"),
+            )
+            .groupBy("studyLocusId", "geneId")
+            .agg(agg_expr)
+        )
+        return cls(
+            _df=convert_from_wide_to_long(
+                wide_df,
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py
index 9495ff905..c0f0ef9b4 100644
--- a/src/gentropy/method/l2g/feature_factory.py
+++ b/src/gentropy/method/l2g/feature_factory.py
@@ -1,21 +1,22 @@
-"""Collection of methods that extract features from the gentropy datasets to be fed in L2G."""
+"""Factory that computes features based on an input list."""
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, Iterator, Mapping
-
-import pyspark.sql.functions as f
-
-from gentropy.common.spark_helpers import convert_from_wide_to_long
-from gentropy.dataset.colocalisation
import Colocalisation -from gentropy.dataset.l2g_feature import L2GFeature +from typing import Any, Iterator, Mapping + +from gentropy.dataset.l2g_feature import ( + EQtlColocClppMaximumFeature, + EQtlColocH4MaximumFeature, + L2GFeature, + PQtlColocClppMaximumFeature, + PQtlColocH4MaximumFeature, + SQtlColocClppMaximumFeature, + SQtlColocH4MaximumFeature, + TuQtlColocClppMaximumFeature, + TuQtlColocH4MaximumFeature, +) from gentropy.dataset.l2g_gold_standard import L2GGoldStandard -from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus -from gentropy.dataset.v2g import V2G - -if TYPE_CHECKING: - from pyspark.sql import DataFrame class L2GFeatureInputLoader: @@ -70,429 +71,6 @@ def __repr__(self) -> str: return repr(self.input_dependencies) -def _common_colocalisation_feature_logic( - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - colocalisation_method: str, - colocalisation_metric: str, - feature_name: str, - qtl_type: str, - *, - colocalisation: Colocalisation, - study_index: StudyIndex, -) -> DataFrame: - """Wrapper to call the logic that creates a type of colocalisation features. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - colocalisation_method (str): The colocalisation method to filter the data by - colocalisation_metric (str): The colocalisation metric to use - feature_name (str): The name of the feature to create - qtl_type (str): The type of QTL to filter the data by - colocalisation (Colocalisation): Dataset with the colocalisation results - study_index (StudyIndex): Study index to fetch study type and gene - - Returns: - DataFrame: Feature annotation in long format with the columns: studyLocusId, geneId, featureName, featureValue - """ - return convert_from_wide_to_long( - colocalisation.extract_maximum_coloc_probability_per_region_and_gene( - study_loci_to_annotate, - study_index, - filter_by_colocalisation_method=colocalisation_method, - filter_by_qtl=qtl_type, - ).selectExpr( - "studyLocusId", - "geneId", - f"{colocalisation_metric} as {feature_name}", - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ) - - -class EQtlColocClppMaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all eQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex] - feature_name = "eQtlColocClppMaximum" - - @classmethod - def compute( - cls: type[EQtlColocClppMaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> EQtlColocClppMaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments. 
- - Returns: - EQtlColocClppMaximumFeature: Feature dataset - """ - colocalisation_method = "ECaviar" - colocalisation_metric = "clpp" - qtl_type = "eqtl" - - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class PQtlColocClppMaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex] - feature_name = "pQtlColocClppMaximum" - - @classmethod - def compute( - cls: type[PQtlColocClppMaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> PQtlColocClppMaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - PQtlColocClppMaximumFeature: Feature dataset - """ - colocalisation_method = "ECaviar" - colocalisation_metric = "clpp" - qtl_type = "pqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class SQtlColocClppMaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all sQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex] - feature_name = "sQtlColocClppMaximum" - - @classmethod - def compute( - cls: type[SQtlColocClppMaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> SQtlColocClppMaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - SQtlColocClppMaximumFeature: Feature dataset - """ - colocalisation_method = "ECaviar" - colocalisation_metric = "clpp" - qtl_type = "sqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class TuQtlColocClppMaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all tuQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex] - feature_name = "tuQtlColocClppMaximum" - - @classmethod - def compute( - cls: type[TuQtlColocClppMaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> TuQtlColocClppMaximumFeature: - """Computes the feature. 
- - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - TuQtlColocClppMaximumFeature: Feature dataset - """ - colocalisation_method = "ECaviar" - colocalisation_metric = "clpp" - qtl_type = "tuqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class EQtlColocH4MaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all eQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex] - feature_name = "eQtlColocH4Maximum" - - @classmethod - def compute( - cls: type[EQtlColocH4MaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> EQtlColocH4MaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - EQtlColocH4MaximumFeature: Feature dataset - """ - colocalisation_method = "Coloc" - colocalisation_metric = "h4" - qtl_type = "eqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class PQtlColocH4MaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex] - feature_name = "pQtlColocH4Maximum" - - @classmethod - def compute( - cls: type[PQtlColocH4MaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> PQtlColocH4MaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - PQtlColocH4MaximumFeature: Feature dataset - """ - colocalisation_method = "Coloc" - colocalisation_metric = "h4" - qtl_type = "pqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class SQtlColocH4MaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all sQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex] - feature_name = "sQtlColocH4Maximum" - - @classmethod - def compute( - cls: type[SQtlColocH4MaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> SQtlColocH4MaximumFeature: - """Computes the feature. 
- - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - SQtlColocH4MaximumFeature: Feature dataset - """ - colocalisation_method = "Coloc" - colocalisation_metric = "h4" - qtl_type = "sqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class TuQtlColocH4MaximumFeature(L2GFeature): - """Max H4 for each (study, locus, gene) aggregating over all tuQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex] - feature_name = "tuQtlColocH4Maximum" - - @classmethod - def compute( - cls: type[TuQtlColocH4MaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> TuQtlColocH4MaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - TuQtlColocH4MaximumFeature: Feature dataset - """ - colocalisation_method = "Coloc" - colocalisation_metric = "h4" - qtl_type = "tuqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class DistanceTssMinimumFeature(L2GFeature): - """Minimum distance of all tagging variants to gene TSS.""" - - @classmethod - def compute( - cls: type[DistanceTssMinimumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: V2G, - ) -> L2GFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (V2G): Dataset that contains the distance information - - Returns: - L2GFeature: Feature dataset - - Raises: - NotImplementedError: Not implemented - """ - raise NotImplementedError - - -class DistanceTssMeanFeature(L2GFeature): - """Average distance of all tagging variants to gene TSS. - - NOTE: to be rewritten taking variant index as input - """ - - fill_na_value = 500_000 - feature_dependency_type = V2G - - @classmethod - def compute( - cls: type[DistanceTssMeanFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: V2G, - ) -> DistanceTssMeanFeature: - """Computes the feature. 
- - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (V2G): Dataset that contains the distance information - - Returns: - DistanceTssMeanFeature: Feature dataset - """ - agg_expr = f.mean("weightedScore").alias("distanceTssMean") - # Everything but expresion is common logic - v2g = feature_dependency.df.filter(f.col("datasourceId") == "canonical_tss") - wide_df = ( - study_loci_to_annotate.df.withColumn( - "variantInLocus", f.explode_outer("locus") - ) - .select( - "studyLocusId", - f.col("variantInLocus.variantId").alias("variantInLocusId"), - f.col("variantInLocus.posteriorProbability").alias( - "variantInLocusPosteriorProbability" - ), - ) - .join( - v2g.selectExpr("variantId as variantInLocusId", "geneId", "score"), - on="variantInLocusId", - how="inner", - ) - .withColumn( - "weightedScore", - f.col("score") * f.col("variantInLocusPosteriorProbability"), - ) - .groupBy("studyLocusId", "geneId") - .agg(agg_expr) - ) - return cls( - _df=convert_from_wide_to_long( - wide_df, - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ), - _schema=cls.get_schema(), - ) - - class FeatureFactory: """Factory class for creating features.""" diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py new file mode 100644 index 000000000..94a96e66c --- /dev/null +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -0,0 +1,57 @@ +"""Test L2G feature generation.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pytest + +from gentropy.dataset.l2g_feature import ( + EQtlColocClppMaximumFeature, + EQtlColocH4MaximumFeature, + L2GFeature, + PQtlColocClppMaximumFeature, + PQtlColocH4MaximumFeature, + SQtlColocClppMaximumFeature, + SQtlColocH4MaximumFeature, + TuQtlColocClppMaximumFeature, + TuQtlColocH4MaximumFeature, +) +from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader + +if TYPE_CHECKING: + from gentropy.dataset.colocalisation import Colocalisation + from gentropy.dataset.study_index import StudyIndex + from gentropy.dataset.study_locus import StudyLocus + + +@pytest.mark.parametrize( + "feature_class", + [ + EQtlColocH4MaximumFeature, + PQtlColocH4MaximumFeature, + SQtlColocH4MaximumFeature, + TuQtlColocH4MaximumFeature, + EQtlColocClppMaximumFeature, + PQtlColocClppMaximumFeature, + SQtlColocClppMaximumFeature, + TuQtlColocClppMaximumFeature, + ], +) +def test_feature_factory_return_type( + feature_class: Any, + mock_study_locus: StudyLocus, + mock_colocalisation: Colocalisation, + mock_study_index: StudyIndex, +) -> None: + """Test that every feature factory returns a L2GFeature dataset.""" + loader = L2GFeatureInputLoader( + colocalisation=mock_colocalisation, study_index=mock_study_index + ) + feature_dataset = feature_class.compute( + study_loci_to_annotate=mock_study_locus, + feature_dependency=loader.get_dependency_by_type( + feature_class.feature_dependency_type + ), + ) + assert isinstance(feature_dataset, L2GFeature) diff --git a/tests/gentropy/method/test_l2g/test_feature_factory.py b/tests/gentropy/method/test_l2g/test_feature_factory.py deleted file mode 100644 index b2939a902..000000000 --- a/tests/gentropy/method/test_l2g/test_feature_factory.py +++ /dev/null @@ -1,193 +0,0 @@ -"""Test locus-to-gene model training.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any - -import pytest - -from 
gentropy.dataset.l2g_feature import L2GFeature -from gentropy.method.l2g.feature_factory import ( - EQtlColocClppMaximumFeature, - EQtlColocH4MaximumFeature, - L2GFeatureInputLoader, - PQtlColocClppMaximumFeature, - PQtlColocH4MaximumFeature, - SQtlColocClppMaximumFeature, - SQtlColocH4MaximumFeature, - TuQtlColocClppMaximumFeature, - TuQtlColocH4MaximumFeature, -) - -if TYPE_CHECKING: - from gentropy.dataset.colocalisation import Colocalisation - from gentropy.dataset.study_index import StudyIndex - from gentropy.dataset.study_locus import StudyLocus - - -# @pytest.fixture(scope="module") -# def model() -> LocusToGeneModel: -# """Creates an instance of the LocusToGene class.""" -# return LocusToGeneModel(model=RandomForestClassifier()) - - -@pytest.mark.parametrize( - "feature_class", - [ - EQtlColocH4MaximumFeature, - PQtlColocH4MaximumFeature, - SQtlColocH4MaximumFeature, - TuQtlColocH4MaximumFeature, - EQtlColocClppMaximumFeature, - PQtlColocClppMaximumFeature, - SQtlColocClppMaximumFeature, - TuQtlColocClppMaximumFeature, - ], -) -def test_colocalisation_feature_type( - feature_class: Any, - mock_study_locus: StudyLocus, - mock_colocalisation: Colocalisation, - mock_study_index: StudyIndex, -) -> None: - """Test that every colocalisation feature type returns a set of L2GFeatures.""" - loader = L2GFeatureInputLoader( - colocalisation=mock_colocalisation, study_index=mock_study_index - ) - feature_dataset = feature_class.compute( - study_loci_to_annotate=mock_study_locus, - feature_dependency=loader.get_dependency_by_type( - feature_class.feature_dependency_type - ), - ) - assert isinstance(feature_dataset, L2GFeature) - - -# class TestColocalisationFactory: -# """Test the ColocalisationFactory methods.""" - -# def test_get_max_coloc_per_credible_set( -# self: TestColocalisationFactory, -# mock_study_locus: StudyLocus, -# mock_study_index: StudyIndex, -# mock_colocalisation: Colocalisation, -# ) -> None: -# """Test the function that extracts the maximum log likelihood ratio for each pair of overlapping study-locus returns the right data type.""" -# coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set( -# mock_colocalisation, -# mock_study_locus, -# mock_study_index, -# ) -# assert isinstance( -# coloc_features, L2GFeature -# ), "Unexpected type returned from _get_max_coloc_per_credible_set" - -# def test_get_max_coloc_per_credible_set_semantic( -# self: TestColocalisationFactory, -# spark: SparkSession, -# ) -> None: -# """Test logic of the function that extracts the maximum log likelihood ratio for each pair of overlapping study-locus.""" -# # Prepare mock datasets based on 2 associations -# credset = StudyLocus( -# _df=spark.createDataFrame( -# # 2 associations with a common variant in the locus -# [ -# { -# "studyLocusId": 1, -# "variantId": "lead1", -# "studyId": "study1", # this is a GWAS -# "locus": [ -# {"variantId": "commonTag", "posteriorProbability": 0.9}, -# ], -# "chromosome": "1", -# }, -# { -# "studyLocusId": 2, -# "variantId": "lead2", -# "studyId": "study2", # this is a eQTL study -# "locus": [ -# {"variantId": "commonTag", "posteriorProbability": 0.9}, -# ], -# "chromosome": "1", -# }, -# ], -# StudyLocus.get_schema(), -# ), -# _schema=StudyLocus.get_schema(), -# ) - -# studies = StudyIndex( -# _df=spark.createDataFrame( -# [ -# { -# "studyId": "study1", -# "studyType": "gwas", -# "traitFromSource": "trait1", -# "projectId": "project1", -# }, -# { -# "studyId": "study2", -# "studyType": "eqtl", -# "geneId": "gene1", -# "traitFromSource": "trait2", -# 
"projectId": "project2", -# }, -# ] -# ), -# _schema=StudyIndex.get_schema(), -# ) -# coloc = Colocalisation( -# _df=spark.createDataFrame( -# [ -# { -# "leftStudyLocusId": 1, -# "rightStudyLocusId": 2, -# "chromosome": "1", -# "colocalisationMethod": "eCAVIAR", -# "numberColocalisingVariants": 1, -# "clpp": 0.81, # 0.9*0.9 -# "log2h4h3": None, -# } -# ], -# schema=Colocalisation.get_schema(), -# ), -# _schema=Colocalisation.get_schema(), -# ) -# expected_coloc_features_df = spark.createDataFrame( -# [ -# (1, "gene1", "eqtlColocClppMaximum", 0.81), -# (1, "gene1", "eqtlColocClppMaximumNeighborhood", -4.0), -# ], -# L2GFeature.get_schema(), -# ) -# # Test -# coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set( -# coloc, -# credset, -# studies, -# ) -# assert coloc_features.df.collect() == expected_coloc_features_df.collect() - - -# class TestStudyLocusFactory: -# """Test the StudyLocusFactory methods.""" - -# def test_get_tss_distance_features( -# self: TestStudyLocusFactory, mock_study_locus: StudyLocus, mock_v2g: V2G -# ) -> None: -# """Test the function that extracts the distance to the TSS.""" -# tss_distance = StudyLocusFactory._get_tss_distance_features( -# mock_study_locus, mock_v2g -# ) -# assert isinstance( -# tss_distance, L2GFeature -# ), "Unexpected model type returned from _get_tss_distance_features" - -# def test_get_vep_features( -# self: TestStudyLocusFactory, mock_study_locus: StudyLocus, mock_v2g: V2G -# ) -> None: -# """Test the function that extracts the VEP features.""" -# vep_features = StudyLocusFactory._get_vep_features(mock_study_locus, mock_v2g) -# assert isinstance( -# vep_features, L2GFeature -# ), "Unexpected model type returned from _get_vep_features" From d3498b4d1261a0730139f23857f8e9ae5d1851ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 18 Sep 2024 17:39:43 +0100 Subject: [PATCH 36/48] docs: update feature docs --- docs/python_api/datasets/l2g_feature.md | 22 ++++++++++++++++++- .../python_api/methods/l2g/feature_factory.md | 4 ++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/docs/python_api/datasets/l2g_feature.md b/docs/python_api/datasets/l2g_feature.md index cf8c3dcf1..bdab67e7c 100644 --- a/docs/python_api/datasets/l2g_feature.md +++ b/docs/python_api/datasets/l2g_feature.md @@ -2,7 +2,27 @@ title: L2G Feature --- -::: gentropy.method.l2g.feature_factory.L2GFeature +## Abstract Class + +::: gentropy.dataset.l2g_feature.L2GFeature + +## Feature Classes + +### Derived from colocalisation + +::: gentropy.dataset.l2g_feature.EQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_feature.PQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_feature.SQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_feature.TuQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_feature.EQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_feature.PQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_feature.SQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_feature.TuQtlColocH4MaximumFeature + +### Derived from distance + +::: gentropy.dataset.l2g_feature.DistanceTssMinimumFeature +::: gentropy.dataset.l2g_feature.DistanceTssMeanFeature ## Schema diff --git a/docs/python_api/methods/l2g/feature_factory.md b/docs/python_api/methods/l2g/feature_factory.md index 244796254..ec812d2da 100644 --- a/docs/python_api/methods/l2g/feature_factory.md +++ b/docs/python_api/methods/l2g/feature_factory.md @@ -1,3 +1,7 @@ --- title: L2G Feature Factory --- + +::: gentropy.method.l2g.feature_factory.FeatureFactory + +::: 
gentropy.method.l2g.feature_factory.L2GFeatureInputLoader From 03b11e2549cb2ab4ef7fb3285f23ae4ed02b7883 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 18 Sep 2024 17:59:21 +0100 Subject: [PATCH 37/48] fix: import --- src/gentropy/dataset/l2g_feature.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py index 20bd5ad34..c2ba73497 100644 --- a/src/gentropy/dataset/l2g_feature.py +++ b/src/gentropy/dataset/l2g_feature.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any import pyspark.sql.functions as f +from pyspark.sql import Window from gentropy.common.schemas import parse_spark_schema from gentropy.common.spark_helpers import convert_from_wide_to_long @@ -16,7 +17,7 @@ from gentropy.dataset.variant_index import VariantIndex if TYPE_CHECKING: - from pyspark.sql import Column, DataFrame, Window + from pyspark.sql import Column, DataFrame from pyspark.sql.types import StructType from gentropy.dataset.l2g_gold_standard import L2GGoldStandard From 87d187785c138fb9555ad0c3c80f4c758eaac4ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Tue, 24 Sep 2024 18:40:30 +0100 Subject: [PATCH 38/48] test: add semantic `TestCommonColocalisationFeatureLogic` --- src/gentropy/dataset/l2g_feature.py | 4 +- tests/gentropy/dataset/test_l2g_feature.py | 205 ++++++++++++++++++++- 2 files changed, 204 insertions(+), 5 deletions(-) diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py index cec83ccec..d9885ca31 100644 --- a/src/gentropy/dataset/l2g_feature.py +++ b/src/gentropy/dataset/l2g_feature.py @@ -170,7 +170,7 @@ def _common_neighbourhood_colocalisation_feature_logic( f.max(local_feature_name).over(Window.partitionBy("studyLocusId")), ) .withColumn(feature_name, f.col("regional_maximum") - f.col(local_feature_name)) - .drop("regional_maximum") + .drop("regional_maximum", local_feature_name) ) @@ -866,7 +866,7 @@ def _common_distance_feature_logic( agg_expr (Column): The expression that aggregate distances into a specific way to define the feature Returns: - DataFrame: Feature dataset + DataFrame: Feature dataset """ distances_dataset = variant_index.get_distance_to_gene(distance_type=distance_type) return ( diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index 91520a445..995db22cf 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -6,6 +6,7 @@ import pytest +from gentropy.dataset.colocalisation import Colocalisation from gentropy.dataset.l2g_feature import ( DistanceFootprintMeanFeature, DistanceFootprintMeanNeighbourhoodFeature, @@ -32,13 +33,16 @@ TuQtlColocClppMaximumNeighbourhoodFeature, TuQtlColocH4MaximumFeature, TuQtlColocH4MaximumNeighbourhoodFeature, + _common_colocalisation_feature_logic, + _common_neighbourhood_colocalisation_feature_logic, ) +from gentropy.dataset.study_index import StudyIndex +from gentropy.dataset.study_locus import StudyLocus from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader if TYPE_CHECKING: - from gentropy.dataset.colocalisation import Colocalisation - from gentropy.dataset.study_index import StudyIndex - from gentropy.dataset.study_locus import StudyLocus + from pyspark.sql import SparkSession + from gentropy.dataset.variant_index import VariantIndex @@ -83,6 +87,7 @@ def test_feature_factory_return_type( colocalisation=mock_colocalisation, study_index=mock_study_index, 
variant_index=mock_variant_index, + study_locus=mock_study_locus, ) feature_dataset = feature_class.compute( study_loci_to_annotate=mock_study_locus, @@ -93,6 +98,200 @@ def test_feature_factory_return_type( assert isinstance(feature_dataset, L2GFeature) +class TestCommonColocalisationFeatureLogic: + """Test the common logic of the colocalisation features.""" + + def test__common_colocalisation_feature_logic( + self: TestCommonColocalisationFeatureLogic, + spark: SparkSession, + ) -> None: + """Test the common logic of the colocalisation features. + + The test data associates studyLocusId1 with gene1 based on the colocalisation with studyLocusId2 and studyLocusId3. + The H4 value of number 2 is higher, therefore the feature value should be based on that. + """ + feature_name = "eQtlColocH4Maximum" + observed_df = _common_colocalisation_feature_logic( + self.sample_study_loci_to_annotate, + self.colocalisation_method, + self.colocalisation_metric, + feature_name, + self.qtl_type, + colocalisation=self.sample_colocalisation, + study_index=self.sample_studies, + study_locus=self.sample_study_locus, + ) + expected_df = spark.createDataFrame( + [ + { + "studyLocusId": 1, + "geneId": "gene1", + "eQtlColocH4Maximum": 0.81, + }, + { + "studyLocusId": 1, + "geneId": "gene2", + "eQtlColocH4Maximum": 0.9, + }, + ], + ).select("studyLocusId", "geneId", "eQtlColocH4Maximum") + assert ( + observed_df.collect() == expected_df.collect() + ), "The feature values are not as expected." + + def test__common_neighbourhood_colocalisation_feature_logic( + self: TestCommonColocalisationFeatureLogic, spark: SparkSession + ) -> None: + """Test the common logic of the neighbourhood colocalisation features.""" + feature_name = "eQtlColocH4MaximumNeighbourhood" + observed_df = _common_neighbourhood_colocalisation_feature_logic( + self.sample_study_loci_to_annotate, + self.colocalisation_method, + self.colocalisation_metric, + feature_name, + self.qtl_type, + colocalisation=self.sample_colocalisation, + study_index=self.sample_studies, + study_locus=self.sample_study_locus, + ) + expected_df = spark.createDataFrame( + [ + { + "studyLocusId": 1, + "geneId": "gene1", + "eQtlColocH4MaximumNeighbourhood": 0.08999999999999997, + }, + { + "studyLocusId": 1, + "geneId": "gene2", + "eQtlColocH4MaximumNeighbourhood": 0.0, + }, + ], + ).select("studyLocusId", "geneId", "eQtlColocH4MaximumNeighbourhood") + assert ( + observed_df.collect() == expected_df.collect() + ), "The expected and observed dataframes do not match." 
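+
+    # A worked sketch of the neighbourhood arithmetic exercised above (added
+    # as illustrative comments, not as an extra assertion): the feature
+    # subtracts each gene's local maximum from the regional maximum across
+    # all genes in the locus, so with h4 maxima of 0.81 (gene1) and 0.90
+    # (gene2):
+    #   regional_maximum = max(0.81, 0.90)  # 0.90
+    #   gene1: 0.90 - 0.81 = 0.09  # a stronger candidate gene exists nearby
+    #   gene2: 0.90 - 0.90 = 0.0   # gene2 is the regional best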
+
+    @pytest.fixture(autouse=True)
+    def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> None:
+        """Set up the test variables."""
+        self.colocalisation_method = "Coloc"
+        self.colocalisation_metric = "h4"
+        self.qtl_type = "eqtl"
+
+        self.sample_study_loci_to_annotate = StudyLocus(
+            _df=spark.createDataFrame(
+                [
+                    {
+                        "studyLocusId": 1,
+                        "variantId": "lead1",
+                        "studyId": "study1",  # this is a GWAS
+                        "chromosome": "1",
+                    },
+                ]
+            ),
+            _schema=StudyLocus.get_schema(),
+        )
+        self.sample_colocalisation = Colocalisation(
+            _df=spark.createDataFrame(
+                [
+                    {
+                        "leftStudyLocusId": 1,
+                        "rightStudyLocusId": 2,
+                        "chromosome": "1",
+                        "colocalisationMethod": "COLOC",
+                        "numberColocalisingVariants": 1,
+                        "h4": 0.81,
+                    },
+                    {
+                        "leftStudyLocusId": 1,
+                        "rightStudyLocusId": 3,  # qtl linked to the same gene as studyLocusId 2 with a lower score
+                        "chromosome": "1",
+                        "colocalisationMethod": "COLOC",
+                        "numberColocalisingVariants": 1,
+                        "h4": 0.50,
+                    },
+                    {
+                        "leftStudyLocusId": 1,
+                        "rightStudyLocusId": 4,  # qtl linked to a diff gene and with the highest score
+                        "chromosome": "1",
+                        "colocalisationMethod": "COLOC",
+                        "numberColocalisingVariants": 1,
+                        "h4": 0.90,
+                    },
+                ],
+                schema=Colocalisation.get_schema(),
+            ),
+            _schema=Colocalisation.get_schema(),
+        )
+        self.sample_study_locus = StudyLocus(
+            _df=spark.createDataFrame(
+                [
+                    {
+                        "studyLocusId": 1,
+                        "variantId": "lead1",
+                        "studyId": "study1",  # this is a GWAS
+                        "chromosome": "1",
+                    },
+                    {
+                        "studyLocusId": 2,
+                        "variantId": "lead1",
+                        "studyId": "study2",  # this is a QTL (same gene)
+                        "chromosome": "1",
+                    },
+                    {
+                        "studyLocusId": 3,
+                        "variantId": "lead1",
+                        "studyId": "study3",  # this is another QTL (same gene)
+                        "chromosome": "1",
+                    },
+                    {
+                        "studyLocusId": 4,
+                        "variantId": "lead1",
+                        "studyId": "study4",  # this is another QTL (diff gene)
+                        "chromosome": "1",
+                    },
+                ]
+            ),
+            _schema=StudyLocus.get_schema(),
+        )
+        self.sample_studies = StudyIndex(
+            _df=spark.createDataFrame(
+                [
+                    {
+                        "studyId": "study1",
+                        "studyType": "gwas",
+                        "geneId": None,
+                        "traitFromSource": "trait1",
+                        "projectId": "project1",
+                    },
+                    {
+                        "studyId": "study2",
+                        "studyType": "eqtl",
+                        "geneId": "gene1",
+                        "traitFromSource": "trait2",
+                        "projectId": "project2",
+                    },
+                    {
+                        "studyId": "study3",
+                        "studyType": "eqtl",
+                        "geneId": "gene1",
+                        "traitFromSource": "trait3",
+                        "projectId": "project3",
+                    },
+                    {
+                        "studyId": "study4",
+                        "studyType": "eqtl",
+                        "geneId": "gene2",
+                        "traitFromSource": "trait4",
+                        "projectId": "project4",
+                    },
+                ]
+            ),
+            _schema=StudyIndex.get_schema(),
+        )
+
+
 # class TestColocalisationFactory:
 #     """Test the ColocalisationFactory methods."""
 #     def test_get_max_coloc_per_credible_set_semantic(
From dbc5d2e3a0d4f6625fcfea4d5437692b1fcfe87e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?=
Date: Wed, 25 Sep 2024 10:43:23 +0100
Subject: [PATCH 39/48] test: add semantic `TestCommonDistanceFeatureLogic`

---
 tests/gentropy/dataset/test_l2g_feature.py | 256 ++++++++++++++-------
 1 file changed, 168 insertions(+), 88 deletions(-)

diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py
index 995db22cf..b77f0d103 100644
--- a/tests/gentropy/dataset/test_l2g_feature.py
+++ b/tests/gentropy/dataset/test_l2g_feature.py
@@ -4,7 +4,16 @@
 from typing import TYPE_CHECKING, Any
 
+import pyspark.sql.functions as f
 import pytest
+from pyspark.sql.types import (
+    ArrayType,
+    IntegerType,
+    LongType,
+    StringType,
+    StructField,
+    StructType,
+)
 
 from gentropy.dataset.colocalisation import
Colocalisation
 from gentropy.dataset.l2g_feature import (
@@ -34,17 +43,18 @@
     TuQtlColocH4MaximumFeature,
     TuQtlColocH4MaximumNeighbourhoodFeature,
     _common_colocalisation_feature_logic,
+    _common_distance_feature_logic,
     _common_neighbourhood_colocalisation_feature_logic,
+    _common_neighbourhood_distance_feature_logic,
 )
 from gentropy.dataset.study_index import StudyIndex
 from gentropy.dataset.study_locus import StudyLocus
+from gentropy.dataset.variant_index import VariantIndex
 from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader
 
 if TYPE_CHECKING:
     from pyspark.sql import SparkSession
 
-    from gentropy.dataset.variant_index import VariantIndex
-
 
 @pytest.mark.parametrize(
     "feature_class",
@@ -292,90 +302,160 @@ def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> N
     )
 
 
-# class TestColocalisationFactory:
-#     """Test the ColocalisationFactory methods."""
-#     def test_get_max_coloc_per_credible_set_semantic(
-#         self: TestColocalisationFactory,
-#         spark: SparkSession,
-#     ) -> None:
-#         """Test logic of the function that extracts the maximum log likelihood ratio for each pair of overlapping study-locus."""
-#         # Prepare mock datasets based on 2 associations
-#         credset = StudyLocus(
-#             _df=spark.createDataFrame(
-#                 # 2 associations with a common variant in the locus
-#                 [
-#                     {
-#                         "studyLocusId": 1,
-#                         "variantId": "lead1",
-#                         "studyId": "study1",  # this is a GWAS
-#                         "locus": [
-#                             {"variantId": "commonTag", "posteriorProbability": 0.9},
-#                         ],
-#                         "chromosome": "1",
-#                     },
-#                     {
-#                         "studyLocusId": 2,
-#                         "variantId": "lead2",
-#                         "studyId": "study2",  # this is a eQTL study
-#                         "locus": [
-#                             {"variantId": "commonTag", "posteriorProbability": 0.9},
-#                         ],
-#                         "chromosome": "1",
-#                     },
-#                 ],
-#                 StudyLocus.get_schema(),
-#             ),
-#             _schema=StudyLocus.get_schema(),
-#         )
+class TestCommonDistanceFeatureLogic:
+    """Test the CommonDistanceFeatureLogic methods."""
 
-#         studies = StudyIndex(
-#             _df=spark.createDataFrame(
-#                 [
-#                     {
-#                         "studyId": "study1",
-#                         "studyType": "gwas",
-#                         "traitFromSource": "trait1",
-#                         "projectId": "project1",
-#                     },
-#                     {
-#                         "studyId": "study2",
-#                         "studyType": "eqtl",
-#                         "geneId": "gene1",
-#                         "traitFromSource": "trait2",
-#                         "projectId": "project2",
-#                     },
-#                 ]
-#             ),
-#             _schema=StudyIndex.get_schema(),
-#         )
-#         coloc = Colocalisation(
-#             _df=spark.createDataFrame(
-#                 [
-#                     {
-#                         "leftStudyLocusId": 1,
-#                         "rightStudyLocusId": 2,
-#                         "chromosome": "1",
-#                         "colocalisationMethod": "eCAVIAR",
-#                         "numberColocalisingVariants": 1,
-#                         "clpp": 0.81,  # 0.9*0.9
-#                         "log2h4h3": None,
-#                     }
-#                 ],
-#                 schema=Colocalisation.get_schema(),
-#             ),
-#             _schema=Colocalisation.get_schema(),
-#         )
-#         expected_coloc_features_df = spark.createDataFrame(
-#             [
-#                 (1, "gene1", "eqtlColocClppMaximum", 0.81),
-#                 (1, "gene1", "eqtlColocClppMaximumNeighborhood", -4.0),
-#             ],
-#             L2GFeature.get_schema(),
-#         )
-#         # Test
-#         coloc_features = ColocalisationFactory._get_max_coloc_per_credible_set(
-#             coloc,
-#             credset,
-#             studies,
-#         )
-#         assert coloc_features.df.collect() == expected_coloc_features_df.collect()
+    @pytest.mark.parametrize(
+        ("feature_name", "expected_distance"),
+        [
+            ("distanceTssMinimum", 2.5),
+            ("distanceTssMean", 3.75),
+        ],
+    )
+    def test__common_distance_feature_logic(
+        self: TestCommonDistanceFeatureLogic,
+        spark: SparkSession,
+        feature_name: str,
+        expected_distance: float,
+    ) -> None:
+        """Test the logic of the function that extracts the distance between the variants in a credible set and a gene."""
+        agg_expr = (
+            f.min(f.col("weightedDistance"))
+            if feature_name == 
"distanceTssMinimum" + else f.mean(f.col("weightedDistance")) + ) + observed_df = _common_distance_feature_logic( + self.sample_study_locus, + variant_index=self.sample_variant_index, + feature_name=feature_name, + distance_type=self.distance_type, + agg_expr=agg_expr, + ) + assert observed_df.first()[feature_name] == expected_distance + + def test__common_neighbourhood_colocalisation_feature_logic( + self: TestCommonDistanceFeatureLogic, + spark: SparkSession, + ) -> None: + """Test the logic of the function that extracts the distance between the variants in a credible set and the nearby genes.""" + another_sample_variant_index = VariantIndex( + _df=spark.createDataFrame( + [ + ( + "lead1", + "chrom", + 1, + "A", + "T", + [ + {"distanceFromTss": 10, "targetId": "gene1"}, + {"distanceFromTss": 100, "targetId": "gene2"}, + ], + ), + ( + "tag1", + "chrom", + 1, + "A", + "T", + [ + {"distanceFromTss": 5, "targetId": "gene1"}, + ], + ), + ], + self.variant_index_schema, + ), + _schema=VariantIndex.get_schema(), + ) + observed_df = _common_neighbourhood_distance_feature_logic( + self.sample_study_locus, + variant_index=another_sample_variant_index, + feature_name="distanceTssMinimum", + distance_type=self.distance_type, + agg_expr=f.min("weightedDistance"), + ).orderBy(f.col("distanceTssMinimum").asc()) + expected_df = spark.createDataFrame( + ([1, "gene2", -47.5], [1, "gene1", 0.0]), + ["studyLocusId", "geneId", "distanceTssMinimum"], + ).orderBy(f.col("distanceTssMinimum").asc()) + assert ( + observed_df.collect() == expected_df.collect() + ), "Output doesn't meet the expectation." + + @pytest.fixture(autouse=True) + def _setup(self: TestCommonDistanceFeatureLogic, spark: SparkSession) -> None: + """Set up testing fixtures.""" + self.distance_type = "distanceFromTss" + self.sample_study_locus = StudyLocus( + _df=spark.createDataFrame( + [ + { + "studyLocusId": 1, + "variantId": "lead1", + "studyId": "study1", + "locus": [ + { + "variantId": "lead1", + "posteriorProbability": 0.5, + }, + { + "variantId": "tag1", # this variant is closer to gene1 + "posteriorProbability": 0.5, + }, + ], + "chromosome": "1", + }, + ], + StudyLocus.get_schema(), + ), + _schema=StudyLocus.get_schema(), + ) + self.variant_index_schema = StructType( + [ + StructField("variantId", StringType(), True), + StructField("chromosome", StringType(), True), + StructField("position", IntegerType(), True), + StructField("referenceAllele", StringType(), True), + StructField("alternateAllele", StringType(), True), + StructField( + "transcriptConsequences", + ArrayType( + StructType( + [ + StructField("distanceFromTss", LongType(), True), + StructField("targetId", StringType(), True), + ] + ) + ), + True, + ), + ] + ) + self.sample_variant_index = VariantIndex( + _df=spark.createDataFrame( + [ + ( + "lead1", + "chrom", + 1, + "A", + "T", + [ + {"distanceFromTss": 10, "targetId": "gene1"}, + ], + ), + ( + "tag1", + "chrom", + 1, + "A", + "T", + [ + {"distanceFromTss": 5, "targetId": "gene1"}, + ], + ), + ], + self.variant_index_schema, + ), + _schema=VariantIndex.get_schema(), + ) From 69d71125bf45d2a12cfcfedef09b6f7f65f98926 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 25 Sep 2024 12:04:57 +0100 Subject: [PATCH 40/48] refactor: separate features into diff modules --- docs/python_api/datasets/l2g_feature.md | 29 -- .../datasets/l2g_features/_l2g_feature.md | 11 + .../datasets/l2g_features/colocalisation.md | 27 + .../datasets/l2g_features/distance.md | 19 + src/gentropy/dataset/l2g_features/__init__.py 
| 3 + .../colocalisation.py} | 479 +----------------- src/gentropy/dataset/l2g_features/distance.py | 431 ++++++++++++++++ .../dataset/l2g_features/l2g_feature.py | 65 +++ src/gentropy/method/l2g/feature_factory.py | 22 +- tests/gentropy/dataset/test_l2g_feature.py | 24 +- 10 files changed, 585 insertions(+), 525 deletions(-) delete mode 100644 docs/python_api/datasets/l2g_feature.md create mode 100644 docs/python_api/datasets/l2g_features/_l2g_feature.md create mode 100644 docs/python_api/datasets/l2g_features/colocalisation.md create mode 100644 docs/python_api/datasets/l2g_features/distance.md create mode 100644 src/gentropy/dataset/l2g_features/__init__.py rename src/gentropy/dataset/{l2g_feature.py => l2g_features/colocalisation.py} (62%) create mode 100644 src/gentropy/dataset/l2g_features/distance.py create mode 100644 src/gentropy/dataset/l2g_features/l2g_feature.py diff --git a/docs/python_api/datasets/l2g_feature.md b/docs/python_api/datasets/l2g_feature.md deleted file mode 100644 index bdab67e7c..000000000 --- a/docs/python_api/datasets/l2g_feature.md +++ /dev/null @@ -1,29 +0,0 @@ ---- -title: L2G Feature ---- - -## Abstract Class - -::: gentropy.dataset.l2g_feature.L2GFeature - -## Feature Classes - -### Derived from colocalisation - -::: gentropy.dataset.l2g_feature.EQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_feature.PQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_feature.SQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_feature.TuQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_feature.EQtlColocH4MaximumFeature -::: gentropy.dataset.l2g_feature.PQtlColocH4MaximumFeature -::: gentropy.dataset.l2g_feature.SQtlColocH4MaximumFeature -::: gentropy.dataset.l2g_feature.TuQtlColocH4MaximumFeature - -### Derived from distance - -::: gentropy.dataset.l2g_feature.DistanceTssMinimumFeature -::: gentropy.dataset.l2g_feature.DistanceTssMeanFeature - -## Schema - ---8<-- "assets/schemas/l2g_feature.md" diff --git a/docs/python_api/datasets/l2g_features/_l2g_feature.md b/docs/python_api/datasets/l2g_features/_l2g_feature.md new file mode 100644 index 000000000..8ac17d530 --- /dev/null +++ b/docs/python_api/datasets/l2g_features/_l2g_feature.md @@ -0,0 +1,11 @@ +--- +title: L2G Feature +--- + +## Abstract Class + +::: gentropy.dataset.l2g_features.L2GFeature + +## Schema + +--8<-- "assets/schemas/l2g_feature.md" diff --git a/docs/python_api/datasets/l2g_features/colocalisation.md b/docs/python_api/datasets/l2g_features/colocalisation.md new file mode 100644 index 000000000..a4b2a5506 --- /dev/null +++ b/docs/python_api/datasets/l2g_features/colocalisation.md @@ -0,0 +1,27 @@ +--- +title: Colocalisation derived features +--- + +## List of features + +::: gentropy.dataset.l2g_features.EQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_features.PQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_features.SQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_features.TuQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_features.EQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.PQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.SQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.TuQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.EQtlColocClppMaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.PQtlColocClppMaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.SQtlColocClppMaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.TuQtlColocClppMaximumNeighbourhoodFeature +::: 
gentropy.dataset.l2g_features.EQtlColocH4MaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.PQtlColocH4MaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.SQtlColocH4MaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.TuQtlColocH4MaximumNeighbourhoodFeature + +## Common logic + +::: gentropy.dataset.l2g_features.\_common_colocalisation_feature_logic +::: gentropy.dataset.l2g_features.\_common_neighbourhood_colocalisation_feature_logic diff --git a/docs/python_api/datasets/l2g_features/distance.md b/docs/python_api/datasets/l2g_features/distance.md new file mode 100644 index 000000000..d41a39a01 --- /dev/null +++ b/docs/python_api/datasets/l2g_features/distance.md @@ -0,0 +1,19 @@ +--- +title: Distance derived features +--- + +## List of features + +::: gentropy.dataset.l2g_features.DistanceTssMinimumFeature +::: gentropy.dataset.l2g_features.DistanceTssMeanFeature +::: gentropy.dataset.l2g_features.DistanceFootprintMinimumFeature +::: gentropy.dataset.l2g_features.DistanceFootprintMeanFeature +::: gentropy.dataset.l2g_features.DistanceTssMinimumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.DistanceTssMeanNeighbourhoodFeature +::: gentropy.dataset.l2g_features.DistanceFootprintMinimumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.DistanceFootprintMeanNeighbourhoodFeature + +## Common logic + +::: gentropy.dataset.l2g_features.\_common_distance_feature_logic +::: gentropy.dataset.l2g_features.\_common_neighbourhood_distance_feature_logic diff --git a/src/gentropy/dataset/l2g_features/__init__.py b/src/gentropy/dataset/l2g_features/__init__.py new file mode 100644 index 000000000..ce15cedfe --- /dev/null +++ b/src/gentropy/dataset/l2g_features/__init__.py @@ -0,0 +1,3 @@ +"""Feature factories for L2G.""" + +from __future__ import annotations diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_features/colocalisation.py similarity index 62% rename from src/gentropy/dataset/l2g_feature.py rename to src/gentropy/dataset/l2g_features/colocalisation.py index d9885ca31..fe4fb065b 100644 --- a/src/gentropy/dataset/l2g_feature.py +++ b/src/gentropy/dataset/l2g_features/colocalisation.py @@ -1,78 +1,21 @@ -"""L2G Feature Dataset with a collection of methods that extract features from the gentropy datasets to be fed in L2G.""" +"""Collection of methods that extract features from the colocalisation datasets.""" from __future__ import annotations -from abc import ABC, abstractmethod -from dataclasses import dataclass from typing import TYPE_CHECKING, Any import pyspark.sql.functions as f from pyspark.sql import Window -from gentropy.common.schemas import parse_spark_schema from gentropy.common.spark_helpers import convert_from_wide_to_long from gentropy.dataset.colocalisation import Colocalisation -from gentropy.dataset.dataset import Dataset +from gentropy.dataset.l2g_features.l2g_feature import L2GFeature from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus -from gentropy.dataset.variant_index import VariantIndex if TYPE_CHECKING: - from pyspark.sql import Column, DataFrame - from pyspark.sql.types import StructType - - from gentropy.dataset.l2g_gold_standard import L2GGoldStandard - from gentropy.dataset.study_locus import StudyLocus - - -@dataclass -class L2GFeature(Dataset, ABC): - """Locus-to-gene feature dataset.""" - - def __post_init__( - self: L2GFeature, - feature_dependency_type: Any = None, - 
credible_set: StudyLocus | None = None, - ) -> None: - """Initializes a L2GFeature dataset. Any child class of L2GFeature must implement the `compute` method. - - Args: - feature_dependency_type (Any): The dependency that the L2GFeature dataset depends on. Defaults to None. - credible_set (StudyLocus | None): The credible set that the L2GFeature dataset is based on. Defaults to None. - """ - super().__post_init__() - self.feature_dependency_type = feature_dependency_type - self.credible_set = credible_set - - @classmethod - def get_schema(cls: type[L2GFeature]) -> StructType: - """Provides the schema for the L2GFeature dataset. - - Returns: - StructType: Schema for the L2GFeature dataset - """ - return parse_spark_schema("l2g_feature.json") - - @classmethod - @abstractmethod - def compute( - cls: type[L2GFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: Any, - ) -> L2GFeature: - """Computes the L2GFeature dataset. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (Any): The dependency that the L2GFeature class needs to compute the feature - Returns: - L2GFeature: a L2GFeature dataset - - Raises: - NotImplementedError: This method must be implemented in the child classes - """ - raise NotImplementedError("Must be implemented in the child classes") + from pyspark.sql import DataFrame def _common_colocalisation_feature_logic( @@ -725,7 +668,7 @@ def compute( class SQtlColocH4MaximumNeighbourhoodFeature(L2GFeature): """Max H4 for each (study, locus) aggregating over all sQTLs.""" - feature_dependency_type = [Colocalisation, StudyIndex] + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] feature_name = "sQtlColocH4MaximumNeighbourhood" @classmethod @@ -846,417 +789,3 @@ def compute( ), _schema=cls.get_schema(), ) - - -def _common_distance_feature_logic( - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - *, - variant_index: VariantIndex, - feature_name: str, - distance_type: str, - agg_expr: Column, -) -> DataFrame: - """Computes the feature. 
- - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - variant_index (VariantIndex): The dataset containing distance to gene information - feature_name (str): The name of the feature - distance_type (str): The type of distance to gene - agg_expr (Column): The expression that aggregate distances into a specific way to define the feature - - Returns: - DataFrame: Feature dataset - """ - distances_dataset = variant_index.get_distance_to_gene(distance_type=distance_type) - return ( - study_loci_to_annotate.df.withColumn("variantInLocus", f.explode_outer("locus")) - .select( - "studyLocusId", - f.col("variantInLocus.variantId").alias("variantInLocusId"), - f.col("variantInLocus.posteriorProbability").alias( - "variantInLocusPosteriorProbability" - ), - ) - .join( - distances_dataset.withColumnRenamed( - "variantId", "variantInLocusId" - ).withColumnRenamed("targetId", "geneId"), - on="variantInLocusId", - how="inner", - ) - .withColumn( - "weightedDistance", - f.col(distance_type) * f.col("variantInLocusPosteriorProbability"), - ) - .groupBy("studyLocusId", "geneId") - .agg(agg_expr.alias(feature_name)) - ) - - -def _common_neighbourhood_distance_feature_logic( - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - *, - variant_index: VariantIndex, - feature_name: str, - distance_type: str, - agg_expr: Column, -) -> DataFrame: - """Calculate the neighbourhood distance feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - variant_index (VariantIndex): The dataset containing distance to gene information - feature_name (str): The name of the feature - distance_type (str): The type of distance to gene - agg_expr (Column): The expression that aggregate distances into a specific way to define the feature - - Returns: - DataFrame: Feature dataset - """ - local_feature_name = feature_name.replace("Neighbourhood", "") - # First compute mean distances to a gene - local_min = _common_distance_feature_logic( - study_loci_to_annotate, - feature_name=local_feature_name, - distance_type=distance_type, - agg_expr=agg_expr, - variant_index=variant_index, - ) - return ( - # Then compute minimum distance in the vicinity (feature will be the same for any gene associated with a studyLocus) - local_min.withColumn( - "regional_minimum", - f.min(local_feature_name).over(Window.partitionBy("studyLocusId")), - ) - .withColumn(feature_name, f.col("regional_minimum") - f.col(local_feature_name)) - .drop("regional_minimum") - ) - - -class DistanceTssMeanFeature(L2GFeature): - """Average distance of all tagging variants to gene TSS.""" - - fill_na_value = 500_000 - feature_dependency_type = VariantIndex - feature_name = "distanceTssMean" - - @classmethod - def compute( - cls: type[DistanceTssMeanFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> DistanceTssMeanFeature: - """Computes the feature. 
- - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset that contains the distance information - - Returns: - DistanceTssMeanFeature: Feature dataset - """ - agg_expr = f.mean("weightedDistance") - distance_type = "distanceFromTss" - return cls( - _df=convert_from_wide_to_long( - _common_distance_feature_logic( - study_loci_to_annotate, - feature_name=cls.feature_name, - distance_type=distance_type, - agg_expr=agg_expr, - **feature_dependency, - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ), - _schema=cls.get_schema(), - ) - - -class DistanceTssMeanNeighbourhoodFeature(L2GFeature): - """Minimum mean distance to TSS for all genes in the vicinity of a studyLocus.""" - - fill_na_value = 500_000 - feature_dependency_type = VariantIndex - feature_name = "distanceTssMeanNeighbourhood" - - @classmethod - def compute( - cls: type[DistanceTssMeanNeighbourhoodFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> DistanceTssMeanNeighbourhoodFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset that contains the distance information - - Returns: - DistanceTssMeanNeighbourhoodFeature: Feature dataset - """ - agg_expr = f.mean("weightedDistance") - distance_type = "distanceFromTss" - return cls( - _df=convert_from_wide_to_long( - _common_neighbourhood_distance_feature_logic( - study_loci_to_annotate, - feature_name=cls.feature_name, - distance_type=distance_type, - agg_expr=agg_expr, - **feature_dependency, - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ), - _schema=cls.get_schema(), - ) - - -class DistanceTssMinimumFeature(L2GFeature): - """Minimum distance of all tagging variants to gene TSS.""" - - fill_na_value = 500_000 - feature_dependency_type = VariantIndex - feature_name = "distanceTssMinimum" - - @classmethod - def compute( - cls: type[DistanceTssMinimumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> DistanceTssMinimumFeature: - """Computes the feature. 
- - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset that contains the distance information - - Returns: - DistanceTssMinimumFeature: Feature dataset - """ - agg_expr = f.mean("weightedDistance") - distance_type = "distanceFromTss" - return cls( - _df=convert_from_wide_to_long( - _common_distance_feature_logic( - study_loci_to_annotate, - feature_name=cls.feature_name, - distance_type=distance_type, - agg_expr=agg_expr, - **feature_dependency, - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ), - _schema=cls.get_schema(), - ) - - -class DistanceTssMinimumNeighbourhoodFeature(L2GFeature): - """Minimum minimum distance to TSS for all genes in the vicinity of a studyLocus.""" - - fill_na_value = 500_000 - feature_dependency_type = VariantIndex - feature_name = "distanceTssMinimumNeighbourhood" - - @classmethod - def compute( - cls: type[DistanceTssMinimumNeighbourhoodFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> DistanceTssMinimumNeighbourhoodFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset that contains the distance information - - Returns: - DistanceTssMinimumNeighbourhoodFeature: Feature dataset - """ - agg_expr = f.min("weightedDistance") - distance_type = "distanceFromTss" - return cls( - _df=convert_from_wide_to_long( - _common_neighbourhood_distance_feature_logic( - study_loci_to_annotate, - feature_name=cls.feature_name, - distance_type=distance_type, - agg_expr=agg_expr, - **feature_dependency, - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ), - _schema=cls.get_schema(), - ) - - -class DistanceFootprintMeanFeature(L2GFeature): - """Average distance of all tagging variants to the footprint of a gene.""" - - fill_na_value = 500_000 - feature_dependency_type = VariantIndex - feature_name = "distanceFootprintMean" - - @classmethod - def compute( - cls: type[DistanceFootprintMeanFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> DistanceFootprintMeanFeature: - """Computes the feature. 
- - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset that contains the distance information - - Returns: - DistanceFootprintMeanFeature: Feature dataset - """ - agg_expr = f.mean("weightedDistance") - distance_type = "distanceFromFootprint" - return cls( - _df=convert_from_wide_to_long( - _common_distance_feature_logic( - study_loci_to_annotate, - feature_name=cls.feature_name, - distance_type=distance_type, - agg_expr=agg_expr, - **feature_dependency, - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ), - _schema=cls.get_schema(), - ) - - -class DistanceFootprintMeanNeighbourhoodFeature(L2GFeature): - """Minimum mean distance to footprint for all genes in the vicinity of a studyLocus.""" - - fill_na_value = 500_000 - feature_dependency_type = VariantIndex - feature_name = "distanceFootprintMeanNeighbourhood" - - @classmethod - def compute( - cls: type[DistanceFootprintMeanNeighbourhoodFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> DistanceFootprintMeanNeighbourhoodFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset that contains the distance information - - Returns: - DistanceFootprintMeanNeighbourhoodFeature: Feature dataset - """ - agg_expr = f.mean("weightedDistance") - distance_type = "distanceFromFootprint" - return cls( - _df=convert_from_wide_to_long( - _common_neighbourhood_distance_feature_logic( - study_loci_to_annotate, - feature_name=cls.feature_name, - distance_type=distance_type, - agg_expr=agg_expr, - **feature_dependency, - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ), - _schema=cls.get_schema(), - ) - - -class DistanceFootprintMinimumFeature(L2GFeature): - """Minimum distance of all tagging variants to the footprint of a gene.""" - - fill_na_value = 500_000 - feature_dependency_type = VariantIndex - feature_name = "DistanceFootprintMinimum" - - @classmethod - def compute( - cls: type[DistanceFootprintMinimumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> DistanceFootprintMinimumFeature: - """Computes the feature. 
- - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset that contains the distance information - - Returns: - DistanceFootprintMinimumFeature: Feature dataset - """ - agg_expr = f.mean("weightedDistance") - distance_type = "distanceFromFootprint" - return cls( - _df=convert_from_wide_to_long( - _common_distance_feature_logic( - study_loci_to_annotate, - feature_name=cls.feature_name, - distance_type=distance_type, - agg_expr=agg_expr, - **feature_dependency, - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ), - _schema=cls.get_schema(), - ) - - -class DistanceFootprintMinimumNeighbourhoodFeature(L2GFeature): - """Minimum minimum distance to footprint for all genes in the vicinity of a studyLocus.""" - - fill_na_value = 500_000 - feature_dependency_type = VariantIndex - feature_name = "distanceFootprintMinimumNeighbourhood" - - @classmethod - def compute( - cls: type[DistanceFootprintMinimumNeighbourhoodFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> DistanceFootprintMinimumNeighbourhoodFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset that contains the distance information - - Returns: - DistanceFootprintMinimumNeighbourhoodFeature: Feature dataset - """ - agg_expr = f.min("weightedDistance") - distance_type = "distanceFromFootprint" - return cls( - _df=convert_from_wide_to_long( - _common_neighbourhood_distance_feature_logic( - study_loci_to_annotate, - feature_name=cls.feature_name, - distance_type=distance_type, - agg_expr=agg_expr, - **feature_dependency, - ), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ), - _schema=cls.get_schema(), - ) diff --git a/src/gentropy/dataset/l2g_features/distance.py b/src/gentropy/dataset/l2g_features/distance.py new file mode 100644 index 000000000..8773732f9 --- /dev/null +++ b/src/gentropy/dataset/l2g_features/distance.py @@ -0,0 +1,431 @@ +"""Collection of methods that extract distance features from the variant index dataset.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pyspark.sql.functions as f +from pyspark.sql import Window + +from gentropy.common.spark_helpers import convert_from_wide_to_long +from gentropy.dataset.l2g_features.l2g_feature import L2GFeature +from gentropy.dataset.l2g_gold_standard import L2GGoldStandard +from gentropy.dataset.study_locus import StudyLocus +from gentropy.dataset.variant_index import VariantIndex + +if TYPE_CHECKING: + from pyspark.sql import Column, DataFrame + + +def _common_distance_feature_logic( + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + *, + variant_index: VariantIndex, + feature_name: str, + distance_type: str, + agg_expr: Column, +) -> DataFrame: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + variant_index (VariantIndex): The dataset containing distance to gene information + feature_name (str): The name of the feature + distance_type (str): The type of distance to gene + agg_expr (Column): The expression that aggregate distances into a specific way to define the feature + + Returns: + DataFrame: Feature dataset + """ + distances_dataset = variant_index.get_distance_to_gene(distance_type=distance_type) + return ( + study_loci_to_annotate.df.withColumn("variantInLocus", f.explode_outer("locus")) + .select( + "studyLocusId", + f.col("variantInLocus.variantId").alias("variantInLocusId"), + f.col("variantInLocus.posteriorProbability").alias( + "variantInLocusPosteriorProbability" + ), + ) + .join( + distances_dataset.withColumnRenamed( + "variantId", "variantInLocusId" + ).withColumnRenamed("targetId", "geneId"), + on="variantInLocusId", + how="inner", + ) + .withColumn( + "weightedDistance", + f.col(distance_type) * f.col("variantInLocusPosteriorProbability"), + ) + .groupBy("studyLocusId", "geneId") + .agg(agg_expr.alias(feature_name)) + ) + + +def _common_neighbourhood_distance_feature_logic( + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + *, + variant_index: VariantIndex, + feature_name: str, + distance_type: str, + agg_expr: Column, +) -> DataFrame: + """Calculate the neighbourhood distance feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + variant_index (VariantIndex): The dataset containing distance to gene information + feature_name (str): The name of the feature + distance_type (str): The type of distance to gene + agg_expr (Column): The expression that aggregate distances into a specific way to define the feature + + Returns: + DataFrame: Feature dataset + """ + local_feature_name = feature_name.replace("Neighbourhood", "") + # First compute mean distances to a gene + local_min = _common_distance_feature_logic( + study_loci_to_annotate, + feature_name=local_feature_name, + distance_type=distance_type, + agg_expr=agg_expr, + variant_index=variant_index, + ) + return ( + # Then compute minimum distance in the vicinity (feature will be the same for any gene associated with a studyLocus) + local_min.withColumn( + "regional_minimum", + f.min(local_feature_name).over(Window.partitionBy("studyLocusId")), + ) + .withColumn(feature_name, f.col("regional_minimum") - f.col(local_feature_name)) + .drop("regional_minimum") + ) + + +class DistanceTssMeanFeature(L2GFeature): + """Average distance of all tagging variants to gene TSS.""" + + fill_na_value = 500_000 + feature_dependency_type = VariantIndex + feature_name = "distanceTssMean" + + @classmethod + def compute( + cls: type[DistanceTssMeanFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> DistanceTssMeanFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the distance information + + Returns: + DistanceTssMeanFeature: Feature dataset + """ + agg_expr = f.mean("weightedDistance") + distance_type = "distanceFromTss" + return cls( + _df=convert_from_wide_to_long( + _common_distance_feature_logic( + study_loci_to_annotate, + feature_name=cls.feature_name, + distance_type=distance_type, + agg_expr=agg_expr, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class DistanceTssMeanNeighbourhoodFeature(L2GFeature): + """Minimum mean distance to TSS for all genes in the vicinity of a studyLocus.""" + + fill_na_value = 500_000 + feature_dependency_type = VariantIndex + feature_name = "distanceTssMeanNeighbourhood" + + @classmethod + def compute( + cls: type[DistanceTssMeanNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> DistanceTssMeanNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the distance information + + Returns: + DistanceTssMeanNeighbourhoodFeature: Feature dataset + """ + agg_expr = f.mean("weightedDistance") + distance_type = "distanceFromTss" + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_distance_feature_logic( + study_loci_to_annotate, + feature_name=cls.feature_name, + distance_type=distance_type, + agg_expr=agg_expr, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class DistanceTssMinimumFeature(L2GFeature): + """Minimum distance of all tagging variants to gene TSS.""" + + fill_na_value = 500_000 + feature_dependency_type = VariantIndex + feature_name = "distanceTssMinimum" + + @classmethod + def compute( + cls: type[DistanceTssMinimumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> DistanceTssMinimumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the distance information + + Returns: + DistanceTssMinimumFeature: Feature dataset + """ + agg_expr = f.mean("weightedDistance") + distance_type = "distanceFromTss" + return cls( + _df=convert_from_wide_to_long( + _common_distance_feature_logic( + study_loci_to_annotate, + feature_name=cls.feature_name, + distance_type=distance_type, + agg_expr=agg_expr, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class DistanceTssMinimumNeighbourhoodFeature(L2GFeature): + """Minimum minimum distance to TSS for all genes in the vicinity of a studyLocus.""" + + fill_na_value = 500_000 + feature_dependency_type = VariantIndex + feature_name = "distanceTssMinimumNeighbourhood" + + @classmethod + def compute( + cls: type[DistanceTssMinimumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> DistanceTssMinimumNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the distance information + + Returns: + DistanceTssMinimumNeighbourhoodFeature: Feature dataset + """ + agg_expr = f.min("weightedDistance") + distance_type = "distanceFromTss" + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_distance_feature_logic( + study_loci_to_annotate, + feature_name=cls.feature_name, + distance_type=distance_type, + agg_expr=agg_expr, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class DistanceFootprintMeanFeature(L2GFeature): + """Average distance of all tagging variants to the footprint of a gene.""" + + fill_na_value = 500_000 + feature_dependency_type = VariantIndex + feature_name = "distanceFootprintMean" + + @classmethod + def compute( + cls: type[DistanceFootprintMeanFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> DistanceFootprintMeanFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the distance information + + Returns: + DistanceFootprintMeanFeature: Feature dataset + """ + agg_expr = f.mean("weightedDistance") + distance_type = "distanceFromFootprint" + return cls( + _df=convert_from_wide_to_long( + _common_distance_feature_logic( + study_loci_to_annotate, + feature_name=cls.feature_name, + distance_type=distance_type, + agg_expr=agg_expr, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class DistanceFootprintMeanNeighbourhoodFeature(L2GFeature): + """Minimum mean distance to footprint for all genes in the vicinity of a studyLocus.""" + + fill_na_value = 500_000 + feature_dependency_type = VariantIndex + feature_name = "distanceFootprintMeanNeighbourhood" + + @classmethod + def compute( + cls: type[DistanceFootprintMeanNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> DistanceFootprintMeanNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the distance information + + Returns: + DistanceFootprintMeanNeighbourhoodFeature: Feature dataset + """ + agg_expr = f.mean("weightedDistance") + distance_type = "distanceFromFootprint" + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_distance_feature_logic( + study_loci_to_annotate, + feature_name=cls.feature_name, + distance_type=distance_type, + agg_expr=agg_expr, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class DistanceFootprintMinimumFeature(L2GFeature): + """Minimum distance of all tagging variants to the footprint of a gene.""" + + fill_na_value = 500_000 + feature_dependency_type = VariantIndex + feature_name = "DistanceFootprintMinimum" + + @classmethod + def compute( + cls: type[DistanceFootprintMinimumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> DistanceFootprintMinimumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the distance information + + Returns: + DistanceFootprintMinimumFeature: Feature dataset + """ + agg_expr = f.mean("weightedDistance") + distance_type = "distanceFromFootprint" + return cls( + _df=convert_from_wide_to_long( + _common_distance_feature_logic( + study_loci_to_annotate, + feature_name=cls.feature_name, + distance_type=distance_type, + agg_expr=agg_expr, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class DistanceFootprintMinimumNeighbourhoodFeature(L2GFeature): + """Minimum minimum distance to footprint for all genes in the vicinity of a studyLocus.""" + + fill_na_value = 500_000 + feature_dependency_type = VariantIndex + feature_name = "distanceFootprintMinimumNeighbourhood" + + @classmethod + def compute( + cls: type[DistanceFootprintMinimumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> DistanceFootprintMinimumNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset that contains the distance information + + Returns: + DistanceFootprintMinimumNeighbourhoodFeature: Feature dataset + """ + agg_expr = f.min("weightedDistance") + distance_type = "distanceFromFootprint" + return cls( + _df=convert_from_wide_to_long( + _common_neighbourhood_distance_feature_logic( + study_loci_to_annotate, + feature_name=cls.feature_name, + distance_type=distance_type, + agg_expr=agg_expr, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) diff --git a/src/gentropy/dataset/l2g_features/l2g_feature.py b/src/gentropy/dataset/l2g_features/l2g_feature.py new file mode 100644 index 000000000..7073ca758 --- /dev/null +++ b/src/gentropy/dataset/l2g_features/l2g_feature.py @@ -0,0 +1,65 @@ +"""L2G Feature Dataset with a collection of methods that extract features from the gentropy datasets to be fed in L2G.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from gentropy.common.schemas import parse_spark_schema +from gentropy.dataset.dataset import Dataset + +if TYPE_CHECKING: + from pyspark.sql.types import StructType + + from gentropy.dataset.l2g_gold_standard import L2GGoldStandard + from gentropy.dataset.study_locus import StudyLocus + + +@dataclass +class L2GFeature(Dataset, ABC): + """Locus-to-gene feature dataset that serves as template to generate each of the features that inform about locus to gene assignments.""" + + def __post_init__( + self: L2GFeature, + feature_dependency_type: Any = None, + credible_set: StudyLocus | None = None, + ) -> None: + """Initializes a L2GFeature dataset. Any child class of L2GFeature must implement the `compute` method. + + Args: + feature_dependency_type (Any): The dependency that the L2GFeature dataset depends on. Defaults to None. + credible_set (StudyLocus | None): The credible set that the L2GFeature dataset is based on. Defaults to None. 
+ """ + super().__post_init__() + self.feature_dependency_type = feature_dependency_type + self.credible_set = credible_set + + @classmethod + def get_schema(cls: type[L2GFeature]) -> StructType: + """Provides the schema for the L2GFeature dataset. + + Returns: + StructType: Schema for the L2GFeature dataset + """ + return parse_spark_schema("l2g_feature.json") + + @classmethod + @abstractmethod + def compute( + cls: type[L2GFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: Any, + ) -> L2GFeature: + """Computes the L2GFeature dataset. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (Any): The dependency that the L2GFeature class needs to compute the feature + Returns: + L2GFeature: a L2GFeature dataset + + Raises: + NotImplementedError: This method must be implemented in the child classes + """ + raise NotImplementedError("Must be implemented in the child classes") diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index b24772c6d..52d7ff76a 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -4,20 +4,11 @@ from typing import Any, Iterator, Mapping -from gentropy.dataset.l2g_feature import ( - DistanceFootprintMeanFeature, - DistanceFootprintMeanNeighbourhoodFeature, - DistanceFootprintMinimumFeature, - DistanceFootprintMinimumNeighbourhoodFeature, - DistanceTssMeanFeature, - DistanceTssMeanNeighbourhoodFeature, - DistanceTssMinimumFeature, - DistanceTssMinimumNeighbourhoodFeature, +from gentropy.dataset.l2g_features.colocalisation import ( EQtlColocClppMaximumFeature, EQtlColocClppMaximumNeighbourhoodFeature, EQtlColocH4MaximumFeature, EQtlColocH4MaximumNeighbourhoodFeature, - L2GFeature, PQtlColocClppMaximumFeature, PQtlColocClppMaximumNeighbourhoodFeature, PQtlColocH4MaximumFeature, @@ -31,6 +22,17 @@ TuQtlColocH4MaximumFeature, TuQtlColocH4MaximumNeighbourhoodFeature, ) +from gentropy.dataset.l2g_features.distance import ( + DistanceFootprintMeanFeature, + DistanceFootprintMeanNeighbourhoodFeature, + DistanceFootprintMinimumFeature, + DistanceFootprintMinimumNeighbourhoodFeature, + DistanceTssMeanFeature, + DistanceTssMeanNeighbourhoodFeature, + DistanceTssMinimumFeature, + DistanceTssMinimumNeighbourhoodFeature, +) +from gentropy.dataset.l2g_features.l2g_feature import L2GFeature from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_locus import StudyLocus diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index b77f0d103..43705b663 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -16,20 +16,11 @@ ) from gentropy.dataset.colocalisation import Colocalisation -from gentropy.dataset.l2g_feature import ( - DistanceFootprintMeanFeature, - DistanceFootprintMeanNeighbourhoodFeature, - DistanceFootprintMinimumFeature, - DistanceFootprintMinimumNeighbourhoodFeature, - DistanceTssMeanFeature, - DistanceTssMeanNeighbourhoodFeature, - DistanceTssMinimumFeature, - DistanceTssMinimumNeighbourhoodFeature, +from gentropy.dataset.l2g_features.colocalisation import ( EQtlColocClppMaximumFeature, EQtlColocClppMaximumNeighbourhoodFeature, EQtlColocH4MaximumFeature, EQtlColocH4MaximumNeighbourhoodFeature, - L2GFeature, PQtlColocClppMaximumFeature, PQtlColocClppMaximumNeighbourhoodFeature, PQtlColocH4MaximumFeature, 
@@ -43,10 +34,21 @@ TuQtlColocH4MaximumFeature, TuQtlColocH4MaximumNeighbourhoodFeature, _common_colocalisation_feature_logic, - _common_distance_feature_logic, _common_neighbourhood_colocalisation_feature_logic, +) +from gentropy.dataset.l2g_features.distance import ( + DistanceFootprintMeanFeature, + DistanceFootprintMeanNeighbourhoodFeature, + DistanceFootprintMinimumFeature, + DistanceFootprintMinimumNeighbourhoodFeature, + DistanceTssMeanFeature, + DistanceTssMeanNeighbourhoodFeature, + DistanceTssMinimumFeature, + DistanceTssMinimumNeighbourhoodFeature, + _common_distance_feature_logic, _common_neighbourhood_distance_feature_logic, ) +from gentropy.dataset.l2g_features.l2g_feature import L2GFeature from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus from gentropy.dataset.variant_index import VariantIndex From d0a9126df4803fbea51c2350401b1bfeac0a9aa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Wed, 25 Sep 2024 15:25:26 +0100 Subject: [PATCH 41/48] fix: documentation references --- .../datasets/l2g_features/_l2g_feature.md | 2 +- .../datasets/l2g_features/colocalisation.md | 38 +++++++++---------- .../datasets/l2g_features/distance.md | 22 +++++------ docs/python_api/methods/l2g/_l2g.md | 7 +--- .../dataset/l2g_features/colocalisation.py | 38 +++++++++---------- tests/gentropy/dataset/test_l2g_feature.py | 20 +++++----- 6 files changed, 62 insertions(+), 65 deletions(-) diff --git a/docs/python_api/datasets/l2g_features/_l2g_feature.md b/docs/python_api/datasets/l2g_features/_l2g_feature.md index 8ac17d530..b2f6f8187 100644 --- a/docs/python_api/datasets/l2g_features/_l2g_feature.md +++ b/docs/python_api/datasets/l2g_features/_l2g_feature.md @@ -4,7 +4,7 @@ title: L2G Feature ## Abstract Class -::: gentropy.dataset.l2g_features.L2GFeature +::: gentropy.dataset.l2g_features.l2g_feature.L2GFeature ## Schema diff --git a/docs/python_api/datasets/l2g_features/colocalisation.md b/docs/python_api/datasets/l2g_features/colocalisation.md index a4b2a5506..a3928c4ab 100644 --- a/docs/python_api/datasets/l2g_features/colocalisation.md +++ b/docs/python_api/datasets/l2g_features/colocalisation.md @@ -1,27 +1,27 @@ --- -title: Colocalisation derived features +title: From colocalisation --- ## List of features -::: gentropy.dataset.l2g_features.EQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_features.PQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_features.SQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_features.TuQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_features.EQtlColocH4MaximumFeature -::: gentropy.dataset.l2g_features.PQtlColocH4MaximumFeature -::: gentropy.dataset.l2g_features.SQtlColocH4MaximumFeature -::: gentropy.dataset.l2g_features.TuQtlColocH4MaximumFeature -::: gentropy.dataset.l2g_features.EQtlColocClppMaximumNeighbourhoodFeature -::: gentropy.dataset.l2g_features.PQtlColocClppMaximumNeighbourhoodFeature -::: gentropy.dataset.l2g_features.SQtlColocClppMaximumNeighbourhoodFeature -::: gentropy.dataset.l2g_features.TuQtlColocClppMaximumNeighbourhoodFeature -::: gentropy.dataset.l2g_features.EQtlColocH4MaximumNeighbourhoodFeature -::: gentropy.dataset.l2g_features.PQtlColocH4MaximumNeighbourhoodFeature -::: gentropy.dataset.l2g_features.SQtlColocH4MaximumNeighbourhoodFeature -::: gentropy.dataset.l2g_features.TuQtlColocH4MaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.EQtlColocClppMaximumFeature +::: 
gentropy.dataset.l2g_features.colocalisation.PQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.SQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.TuQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.EQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.PQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.SQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.TuQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.EQtlColocClppMaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.PQtlColocClppMaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.SQtlColocClppMaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.TuQtlColocClppMaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.EQtlColocH4MaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.PQtlColocH4MaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.SQtlColocH4MaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.TuQtlColocH4MaximumNeighbourhoodFeature ## Common logic -::: gentropy.dataset.l2g_features.\_common_colocalisation_feature_logic -::: gentropy.dataset.l2g_features.\_common_neighbourhood_colocalisation_feature_logic +::: gentropy.dataset.l2g_features.colocalisation.common_colocalisation_feature_logic +::: gentropy.dataset.l2g_features.colocalisation.common_neighbourhood_colocalisation_feature_logic diff --git a/docs/python_api/datasets/l2g_features/distance.md b/docs/python_api/datasets/l2g_features/distance.md index d41a39a01..af432a6e5 100644 --- a/docs/python_api/datasets/l2g_features/distance.md +++ b/docs/python_api/datasets/l2g_features/distance.md @@ -1,19 +1,19 @@ --- -title: Distance derived features +title: From distance --- ## List of features -::: gentropy.dataset.l2g_features.DistanceTssMinimumFeature -::: gentropy.dataset.l2g_features.DistanceTssMeanFeature -::: gentropy.dataset.l2g_features.DistanceFootprintMinimumFeature -::: gentropy.dataset.l2g_features.DistanceFootprintMeanFeature -::: gentropy.dataset.l2g_features.DistanceTssMinimumNeighbourhoodFeature -::: gentropy.dataset.l2g_features.DistanceTssMeanNeighbourhoodFeature -::: gentropy.dataset.l2g_features.DistanceFootprintMinimumNeighbourhoodFeature -::: gentropy.dataset.l2g_features.DistanceFootprintMeanNeighbourhoodFeature +::: gentropy.dataset.l2g_features.distance.DistanceTssMinimumFeature +::: gentropy.dataset.l2g_features.distance.DistanceTssMeanFeature +::: gentropy.dataset.l2g_features.distance.DistanceFootprintMinimumFeature +::: gentropy.dataset.l2g_features.distance.DistanceFootprintMeanFeature +::: gentropy.dataset.l2g_features.distance.DistanceTssMinimumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.distance.DistanceTssMeanNeighbourhoodFeature +::: gentropy.dataset.l2g_features.distance.DistanceFootprintMinimumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.distance.DistanceFootprintMeanNeighbourhoodFeature ## Common logic -::: gentropy.dataset.l2g_features.\_common_distance_feature_logic -::: gentropy.dataset.l2g_features.\_common_neighbourhood_distance_feature_logic +::: gentropy.dataset.l2g_features.distance.common_distance_feature_logic +::: gentropy.dataset.l2g_features.distance.common_neighbourhood_distance_feature_logic diff --git a/docs/python_api/methods/l2g/_l2g.md 
b/docs/python_api/methods/l2g/_l2g.md index fca3ba79d..bbd7dad66 100644 --- a/docs/python_api/methods/l2g/_l2g.md +++ b/docs/python_api/methods/l2g/_l2g.md @@ -9,13 +9,10 @@ The **“locus-to-gene” (L2G)** model derives features to prioritize likely ca - **Chromatin Interaction:** (e.g., promoter-capture Hi-C) - **Variant Pathogenicity:** (from VEP) -The L2G model is distinct from the variant-to-gene (V2G) pipeline in that it: - -- Uses a machine-learning model to learn the weights of each evidence source based on a gold standard of previously identified causal genes. -- Relies upon fine-mapping and colocalization data. - Some of the predictive features weight variant-to-gene (or genomic region-to-gene) evidence based on the posterior probability that the variant is causal, determined through fine-mapping of the GWAS association. +For a more detailed description of how each feature is computed, see [the L2G Feature documentation](../../datasets/l2g_features/_l2g_feature.md). + Details of the L2G model are provided in our Nature Genetics publication (ref - [Nature Genetics Publication](https://www.nature.com/articles/s41588-021-00945-5)): - **Title:** An open approach to systematically prioritize causal variants and genes at all published human GWAS trait-associated loci. diff --git a/src/gentropy/dataset/l2g_features/colocalisation.py b/src/gentropy/dataset/l2g_features/colocalisation.py index fe4fb065b..c44573b72 100644 --- a/src/gentropy/dataset/l2g_features/colocalisation.py +++ b/src/gentropy/dataset/l2g_features/colocalisation.py @@ -18,7 +18,7 @@ from pyspark.sql import DataFrame -def _common_colocalisation_feature_logic( +def common_colocalisation_feature_logic( study_loci_to_annotate: StudyLocus | L2GGoldStandard, colocalisation_method: str, colocalisation_metric: str, @@ -68,7 +68,7 @@ def _common_colocalisation_feature_logic( ) -def _common_neighbourhood_colocalisation_feature_logic( +def common_neighbourhood_colocalisation_feature_logic( study_loci_to_annotate: StudyLocus | L2GGoldStandard, colocalisation_method: str, colocalisation_metric: str, @@ -96,7 +96,7 @@ def _common_neighbourhood_colocalisation_feature_logic( """ # First maximum colocalisation score for each studylocus, gene local_feature_name = feature_name.replace("Neighbourhood", "") - local_max = _common_colocalisation_feature_logic( + local_max = common_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -144,7 +144,7 @@ def compute( return cls( _df=convert_from_wide_to_long( - _common_colocalisation_feature_logic( + common_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -187,7 +187,7 @@ def compute( return cls( _df=convert_from_wide_to_long( - _common_neighbourhood_colocalisation_feature_logic( + common_neighbourhood_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -229,7 +229,7 @@ def compute( qtl_type = "pqtl" return cls( _df=convert_from_wide_to_long( - _common_colocalisation_feature_logic( + common_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -271,7 +271,7 @@ def compute( qtl_type = "pqtl" return cls( _df=convert_from_wide_to_long( - _common_neighbourhood_colocalisation_feature_logic( + common_neighbourhood_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -313,7 +313,7 @@ def compute( qtl_type = "sqtl" return cls( _df=convert_from_wide_to_long( 
- _common_colocalisation_feature_logic( + common_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -355,7 +355,7 @@ def compute( qtl_type = "sqtl" return cls( _df=convert_from_wide_to_long( - _common_neighbourhood_colocalisation_feature_logic( + common_neighbourhood_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -397,7 +397,7 @@ def compute( qtl_type = "tuqtl" return cls( _df=convert_from_wide_to_long( - _common_colocalisation_feature_logic( + common_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -439,7 +439,7 @@ def compute( qtl_type = "tuqtl" return cls( _df=convert_from_wide_to_long( - _common_neighbourhood_colocalisation_feature_logic( + common_neighbourhood_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -481,7 +481,7 @@ def compute( qtl_type = "eqtl" return cls( _df=convert_from_wide_to_long( - _common_colocalisation_feature_logic( + common_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -523,7 +523,7 @@ def compute( qtl_type = "eqtl" return cls( _df=convert_from_wide_to_long( - _common_neighbourhood_colocalisation_feature_logic( + common_neighbourhood_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -565,7 +565,7 @@ def compute( qtl_type = "pqtl" return cls( _df=convert_from_wide_to_long( - _common_colocalisation_feature_logic( + common_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -607,7 +607,7 @@ def compute( qtl_type = "pqtl" return cls( _df=convert_from_wide_to_long( - _common_neighbourhood_colocalisation_feature_logic( + common_neighbourhood_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -649,7 +649,7 @@ def compute( qtl_type = "sqtl" return cls( _df=convert_from_wide_to_long( - _common_colocalisation_feature_logic( + common_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -691,7 +691,7 @@ def compute( qtl_type = "sqtl" return cls( _df=convert_from_wide_to_long( - _common_neighbourhood_colocalisation_feature_logic( + common_neighbourhood_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -733,7 +733,7 @@ def compute( qtl_type = "tuqtl" return cls( _df=convert_from_wide_to_long( - _common_colocalisation_feature_logic( + common_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, @@ -775,7 +775,7 @@ def compute( qtl_type = "tuqtl" return cls( _df=convert_from_wide_to_long( - _common_colocalisation_feature_logic( + common_colocalisation_feature_logic( study_loci_to_annotate, colocalisation_method, colocalisation_metric, diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index 43705b663..44e715ed7 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -33,8 +33,8 @@ TuQtlColocClppMaximumNeighbourhoodFeature, TuQtlColocH4MaximumFeature, TuQtlColocH4MaximumNeighbourhoodFeature, - _common_colocalisation_feature_logic, - _common_neighbourhood_colocalisation_feature_logic, + common_colocalisation_feature_logic, + common_neighbourhood_colocalisation_feature_logic, ) from 
gentropy.dataset.l2g_features.distance import ( DistanceFootprintMeanFeature, @@ -45,8 +45,8 @@ DistanceTssMeanNeighbourhoodFeature, DistanceTssMinimumFeature, DistanceTssMinimumNeighbourhoodFeature, - _common_distance_feature_logic, - _common_neighbourhood_distance_feature_logic, + common_distance_feature_logic, + common_neighbourhood_distance_feature_logic, ) from gentropy.dataset.l2g_features.l2g_feature import L2GFeature from gentropy.dataset.study_index import StudyIndex @@ -123,7 +123,7 @@ def test__common_colocalisation_feature_logic( The H4 value of number 2 is higher, therefore the feature value should be based on that. """ feature_name = "eQtlColocH4Maximum" - observed_df = _common_colocalisation_feature_logic( + observed_df = common_colocalisation_feature_logic( self.sample_study_loci_to_annotate, self.colocalisation_method, self.colocalisation_metric, @@ -156,7 +156,7 @@ def test__common_neighbourhood_colocalisation_feature_logic( ) -> None: """Test the common logic of the neighbourhood colocalisation features.""" feature_name = "eQtlColocH4MaximumNeighbourhood" - observed_df = _common_neighbourhood_colocalisation_feature_logic( + observed_df = common_neighbourhood_colocalisation_feature_logic( self.sample_study_loci_to_annotate, self.colocalisation_method, self.colocalisation_metric, @@ -314,7 +314,7 @@ class TestCommonDistanceFeatureLogic: ("distanceTssMean", 3.75), ], ) - def test__common_distance_feature_logic( + def test_common_distance_feature_logic( self: TestCommonDistanceFeatureLogic, spark: SparkSession, feature_name: str, @@ -326,7 +326,7 @@ def test__common_distance_feature_logic( if feature_name == "distanceTssMinimum" else f.mean(f.col("weightedDistance")) ) - observed_df = _common_distance_feature_logic( + observed_df = common_distance_feature_logic( self.sample_study_locus, variant_index=self.sample_variant_index, feature_name=feature_name, @@ -335,7 +335,7 @@ def test__common_distance_feature_logic( ) assert observed_df.first()[feature_name] == expected_distance - def test__common_neighbourhood_colocalisation_feature_logic( + def test_common_neighbourhood_colocalisation_feature_logic( self: TestCommonDistanceFeatureLogic, spark: SparkSession, ) -> None: @@ -369,7 +369,7 @@ def test__common_neighbourhood_colocalisation_feature_logic( ), _schema=VariantIndex.get_schema(), ) - observed_df = _common_neighbourhood_distance_feature_logic( + observed_df = common_neighbourhood_distance_feature_logic( self.sample_study_locus, variant_index=another_sample_variant_index, feature_name="distanceTssMinimum", From bb47b01dcd2213b9a40e15abcdcdd1212163b2ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 27 Sep 2024 13:43:04 +0100 Subject: [PATCH 42/48] feat: implement distance to sentinel and adapt definitions --- .../datasets/l2g_features/distance.md | 8 +- src/gentropy/config.py | 10 ++ src/gentropy/dataset/l2g_features/distance.py | 155 +++++++++--------- src/gentropy/method/l2g/feature_factory.py | 16 +- tests/gentropy/conftest.py | 2 +- tests/gentropy/dataset/test_l2g.py | 2 +- tests/gentropy/dataset/test_l2g_feature.py | 133 +++++++-------- 7 files changed, 165 insertions(+), 161 deletions(-) diff --git a/docs/python_api/datasets/l2g_features/distance.md b/docs/python_api/datasets/l2g_features/distance.md index af432a6e5..56e9069b3 100644 --- a/docs/python_api/datasets/l2g_features/distance.md +++ b/docs/python_api/datasets/l2g_features/distance.md @@ -4,13 +4,13 @@ title: From distance ## List of features -::: 
gentropy.dataset.l2g_features.distance.DistanceTssMinimumFeature +::: gentropy.dataset.l2g_features.distance.DistanceSentinelTssMinimumFeature ::: gentropy.dataset.l2g_features.distance.DistanceTssMeanFeature -::: gentropy.dataset.l2g_features.distance.DistanceFootprintMinimumFeature +::: gentropy.dataset.l2g_features.distance.DistanceSentinelFootprintMinimumFeature ::: gentropy.dataset.l2g_features.distance.DistanceFootprintMeanFeature -::: gentropy.dataset.l2g_features.distance.DistanceTssMinimumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.distance.DistanceTssFootprintMinimumNeighbourhoodFeature ::: gentropy.dataset.l2g_features.distance.DistanceTssMeanNeighbourhoodFeature -::: gentropy.dataset.l2g_features.distance.DistanceFootprintMinimumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.distance.DistanceSentinelFootprintMinimumNeighbourhoodFeature ::: gentropy.dataset.l2g_features.distance.DistanceFootprintMeanNeighbourhoodFeature ## Common logic diff --git a/src/gentropy/config.py b/src/gentropy/config.py index 180bca1f7..7801794c4 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -248,6 +248,16 @@ class LocusToGeneConfig(StepConfig): "pQtlColocH4Maximum", "sQtlColocH4Maximum", "tuQtlColocH4Maximum", + # distance to gene footprint + "distanceSentinelFootprint", + "distanceSentinelFootprintNeighbourhood", + "distanceFootprintMean", + "distanceFootprintMeanNeighbourhood", + # distance to gene tss + "distanceTssMean", + "distanceTssMeanNeighbourhood", + "distanceSentinelTss", + "distanceSentinelTssNeighbourhood", ] ) hyperparameters: dict[str, Any] = field( diff --git a/src/gentropy/dataset/l2g_features/distance.py b/src/gentropy/dataset/l2g_features/distance.py index 8773732f9..ea030108c 100644 --- a/src/gentropy/dataset/l2g_features/distance.py +++ b/src/gentropy/dataset/l2g_features/distance.py @@ -14,92 +14,99 @@ from gentropy.dataset.variant_index import VariantIndex if TYPE_CHECKING: - from pyspark.sql import Column, DataFrame + from pyspark.sql import DataFrame -def _common_distance_feature_logic( +def common_distance_feature_logic( study_loci_to_annotate: StudyLocus | L2GGoldStandard, *, variant_index: VariantIndex, feature_name: str, distance_type: str, - agg_expr: Column, + genomic_window: int = 500_000, ) -> DataFrame: - """Computes the feature. + """Calculate the distance feature that correlates a variant in a credible set with a gene. + + The distance is weighted by the posterior probability of the variant to factor in its contribution to the trait when we look at the average distance score for all variants in the credible set. 
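    In rough pseudocode (a reading of the expressions below, not an exact
    transcription), each variant/gene pair is scored as

        distance_score = log10((genomic_window - distance + 1) * posteriorProbability)

    for the mean-based features, and the scores are then averaged per
    (studyLocusId, geneId) pair; the sentinel-based features keep the same
    log-scaled term but drop the posteriorProbability factor.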
Args:
        study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
        variant_index (VariantIndex): The dataset containing distance to gene information
        feature_name (str): The name of the feature
        distance_type (str): The type of distance to gene
-        agg_expr (Column): The expression that aggregate distances into a specific way to define the feature
+        genomic_window (int): The maximum window size to consider

    Returns:
        DataFrame: Feature dataset
    """
    distances_dataset = variant_index.get_distance_to_gene(distance_type=distance_type)
-    return (
-        study_loci_to_annotate.df.withColumn("variantInLocus", f.explode_outer("locus"))
-        .select(
+    if "Mean" in feature_name:
+        # Weighting by the SNP contribution is only applied when we are averaging all distances
+        distance_score_expr = (
+            f.lit(genomic_window) - f.col(distance_type) + f.lit(1)
+        ) * f.col("posteriorProbability")
+        agg_expr = f.mean(f.col("distance_score"))
+        df = study_loci_to_annotate.df.withColumn(
+            "variantInLocus", f.explode_outer("locus")
+        ).select(
            "studyLocusId",
-            f.col("variantInLocus.variantId").alias("variantInLocusId"),
-            f.col("variantInLocus.posteriorProbability").alias(
-                "variantInLocusPosteriorProbability"
-            ),
+            f.col("variantInLocus.variantId").alias("variantId"),
+            f.col("variantInLocus.posteriorProbability").alias("posteriorProbability"),
        )
-        .join(
-            distances_dataset.withColumnRenamed(
-                "variantId", "variantInLocusId"
-            ).withColumnRenamed("targetId", "geneId"),
-            on="variantInLocusId",
+    elif "Sentinel" in feature_name:
+        # For sentinel features we calculate the unweighted distance between the sentinel (lead) variant and the gene
+        distance_score_expr = f.lit(genomic_window) - f.col(distance_type) + f.lit(1)
+        agg_expr = f.first(f.col("distance_score"))
+        df = study_loci_to_annotate.df.select("studyLocusId", "variantId")
+    return (
+        df.join(
+            distances_dataset.withColumnRenamed("targetId", "geneId"),
+            on="variantId",
            how="inner",
        )
-        .withColumn(
-            "weightedDistance",
-            f.col(distance_type) * f.col("variantInLocusPosteriorProbability"),
-        )
+        .withColumn("distance_score", f.log10(distance_score_expr))
        .groupBy("studyLocusId", "geneId")
        .agg(agg_expr.alias(feature_name))
    )


-def _common_neighbourhood_distance_feature_logic(
+def common_neighbourhood_distance_feature_logic(
    study_loci_to_annotate: StudyLocus | L2GGoldStandard,
    *,
    variant_index: VariantIndex,
    feature_name: str,
    distance_type: str,
-    agg_expr: Column,
+    genomic_window: int = 500_000,
) -> DataFrame:
-    """Calculate the neighbourhood distance feature.
+    """Calculate the distance feature that correlates any variant in a credible set with any gene near the locus.

+    The distance is weighted by the posterior probability of the variant to factor in its contribution to the trait.
Args:
        study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
        variant_index (VariantIndex): The dataset containing distance to gene information
        feature_name (str): The name of the feature
        distance_type (str): The type of distance to gene
-        agg_expr (Column): The expression that aggregate distances into a specific way to define the feature
+        genomic_window (int): The maximum window size to consider

    Returns:
-        DataFrame: Feature dataset
+        DataFrame: Feature dataset
    """
    local_feature_name = feature_name.replace("Neighbourhood", "")
    # First compute mean distances to a gene
-    local_min = _common_distance_feature_logic(
+    local_metric = common_distance_feature_logic(
        study_loci_to_annotate,
        feature_name=local_feature_name,
        distance_type=distance_type,
-        agg_expr=agg_expr,
        variant_index=variant_index,
+        genomic_window=genomic_window,
    )
    return (
-        # Then compute minimum distance in the vicinity (feature will be the same for any gene associated with a studyLocus)
-        local_min.withColumn(
-            "regional_minimum",
-            f.min(local_feature_name).over(Window.partitionBy("studyLocusId")),
+        # Then compute the mean distance in the vicinity (the regional average is the same for every gene associated with a studyLocus)
+        local_metric.withColumn(
+            "regional_metric",
+            f.mean(f.col(local_feature_name)).over(Window.partitionBy("studyLocusId")),
        )
-        .withColumn(feature_name, f.col("regional_minimum") - f.col(local_feature_name))
-        .drop("regional_minimum")
+        .withColumn(feature_name, f.col(local_feature_name) - f.col("regional_metric"))
+        .drop("regional_metric", local_feature_name)
    )
@@ -125,15 +132,13 @@ def compute(
    Returns:
        DistanceTssMeanFeature: Feature dataset
    """
-    agg_expr = f.mean("weightedDistance")
    distance_type = "distanceFromTss"
    return cls(
        _df=convert_from_wide_to_long(
-            _common_distance_feature_logic(
+            common_distance_feature_logic(
                study_loci_to_annotate,
                feature_name=cls.feature_name,
                distance_type=distance_type,
-                agg_expr=agg_expr,
                **feature_dependency,
            ),
            id_vars=("studyLocusId", "geneId"),
@@ -166,15 +171,13 @@ def compute(
    Returns:
        DistanceTssMeanNeighbourhoodFeature: Feature dataset
    """
-    agg_expr = f.mean("weightedDistance")
    distance_type = "distanceFromTss"
    return cls(
        _df=convert_from_wide_to_long(
-            _common_neighbourhood_distance_feature_logic(
+            common_neighbourhood_distance_feature_logic(
                study_loci_to_annotate,
                feature_name=cls.feature_name,
                distance_type=distance_type,
-                agg_expr=agg_expr,
                **feature_dependency,
            ),
            id_vars=("studyLocusId", "geneId"),
    )


-class DistanceTssMinimumFeature(L2GFeature):
-    """Minimum distance of all tagging variants to gene TSS."""
+class DistanceSentinelTssFeature(L2GFeature):
+    """Distance of the sentinel variant to gene TSS. This is not weighted by the causal probability."""

    fill_na_value = 500_000
    feature_dependency_type = VariantIndex
-    feature_name = "distanceTssMinimum"
+    feature_name = "distanceSentinelTss"

    @classmethod
    def compute(
-        cls: type[DistanceTssMinimumFeature],
+        cls: type[DistanceSentinelTssFeature],
        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
        feature_dependency: dict[str, Any],
-    ) -> DistanceTssMinimumFeature:
+    ) -> DistanceSentinelTssFeature:
        """Computes the feature.
Args:
@@ -205,17 +208,15 @@ def compute(
        feature_dependency (dict[str, Any]): Dataset that contains the distance information

    Returns:
-        DistanceTssMinimumFeature: Feature dataset
+        DistanceSentinelTssFeature: Feature dataset
    """
-    agg_expr = f.mean("weightedDistance")
    distance_type = "distanceFromTss"
    return cls(
        _df=convert_from_wide_to_long(
-            _common_distance_feature_logic(
+            common_distance_feature_logic(
                study_loci_to_annotate,
                feature_name=cls.feature_name,
                distance_type=distance_type,
-                agg_expr=agg_expr,
                **feature_dependency,
            ),
            id_vars=("studyLocusId", "geneId"),
    )


-class DistanceTssMinimumNeighbourhoodFeature(L2GFeature):
-    """Minimum minimum distance to TSS for all genes in the vicinity of a studyLocus."""
+class DistanceSentinelTssNeighbourhoodFeature(L2GFeature):
+    """Distance between the sentinel variant and a gene TSS, relative to the distances of all the genes in the vicinity of a studyLocus. This is not weighted by the causal probability."""

    fill_na_value = 500_000
    feature_dependency_type = VariantIndex
-    feature_name = "distanceTssMinimumNeighbourhood"
+    feature_name = "distanceSentinelTssNeighbourhood"

    @classmethod
    def compute(
-        cls: type[DistanceTssMinimumNeighbourhoodFeature],
+        cls: type[DistanceSentinelTssNeighbourhoodFeature],
        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
        feature_dependency: dict[str, Any],
-    ) -> DistanceTssMinimumNeighbourhoodFeature:
+    ) -> DistanceSentinelTssNeighbourhoodFeature:
        """Computes the feature.

        Args:
@@ -246,17 +247,15 @@ def compute(
        feature_dependency (dict[str, Any]): Dataset that contains the distance information

    Returns:
-        DistanceTssMinimumNeighbourhoodFeature: Feature dataset
+        DistanceSentinelTssNeighbourhoodFeature: Feature dataset
    """
-    agg_expr = f.min("weightedDistance")
    distance_type = "distanceFromTss"
    return cls(
        _df=convert_from_wide_to_long(
-            _common_neighbourhood_distance_feature_logic(
+            common_neighbourhood_distance_feature_logic(
                study_loci_to_annotate,
                feature_name=cls.feature_name,
                distance_type=distance_type,
-                agg_expr=agg_expr,
                **feature_dependency,
            ),
            id_vars=("studyLocusId", "geneId"),
@@ -289,15 +288,13 @@ def compute(
    Returns:
        DistanceFootprintMeanFeature: Feature dataset
    """
-    agg_expr = f.mean("weightedDistance")
    distance_type = "distanceFromFootprint"
    return cls(
        _df=convert_from_wide_to_long(
-            _common_distance_feature_logic(
+            common_distance_feature_logic(
                study_loci_to_annotate,
                feature_name=cls.feature_name,
                distance_type=distance_type,
-                agg_expr=agg_expr,
                **feature_dependency,
            ),
            id_vars=("studyLocusId", "geneId"),
@@ -330,15 +327,13 @@ def compute(
    Returns:
        DistanceFootprintMeanNeighbourhoodFeature: Feature dataset
    """
-    agg_expr = f.mean("weightedDistance")
    distance_type = "distanceFromFootprint"
    return cls(
        _df=convert_from_wide_to_long(
-            _common_neighbourhood_distance_feature_logic(
+            common_neighbourhood_distance_feature_logic(
                study_loci_to_annotate,
                feature_name=cls.feature_name,
                distance_type=distance_type,
-                agg_expr=agg_expr,
                **feature_dependency,
            ),
            id_vars=("studyLocusId", "geneId"),
    )


-class DistanceFootprintMinimumFeature(L2GFeature):
-    """Minimum distance of all tagging variants to the footprint of a gene."""
+class DistanceSentinelFootprintFeature(L2GFeature):
+    """Distance between the sentinel variant and the footprint of a gene."""

    fill_na_value = 500_000
    feature_dependency_type = VariantIndex
-    feature_name = "DistanceFootprintMinimum"
+    feature_name = "distanceSentinelFootprint"
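    # A hypothetical usage sketch (not part of this patch): `credible_set` stands in
    # for a real StudyLocus dataset and `variant_index` for a VariantIndex dataset.
    #
    #   feature = DistanceSentinelFootprintFeature.compute(
    #       credible_set,
    #       feature_dependency={"variant_index": variant_index},
    #   )
    #   feature.df.show()  # long format: one row per (studyLocusId, geneId) pair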
@classmethod
    def compute(
-        cls: type[DistanceFootprintMinimumFeature],
+        cls: type[DistanceSentinelFootprintFeature],
        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
        feature_dependency: dict[str, Any],
-    ) -> DistanceFootprintMinimumFeature:
+    ) -> DistanceSentinelFootprintFeature:
        """Computes the feature.

        Args:
@@ -369,17 +364,15 @@ def compute(
        feature_dependency (dict[str, Any]): Dataset that contains the distance information

    Returns:
-        DistanceFootprintMinimumFeature: Feature dataset
+        DistanceSentinelFootprintFeature: Feature dataset
    """
-    agg_expr = f.mean("weightedDistance")
    distance_type = "distanceFromFootprint"
    return cls(
        _df=convert_from_wide_to_long(
-            _common_distance_feature_logic(
+            common_distance_feature_logic(
                study_loci_to_annotate,
                feature_name=cls.feature_name,
                distance_type=distance_type,
-                agg_expr=agg_expr,
                **feature_dependency,
            ),
            id_vars=("studyLocusId", "geneId"),
    )


-class DistanceFootprintMinimumNeighbourhoodFeature(L2GFeature):
-    """Minimum minimum distance to footprint for all genes in the vicinity of a studyLocus."""
+class DistanceSentinelFootprintNeighbourhoodFeature(L2GFeature):
+    """Distance between the sentinel variant and a gene footprint, relative to the distances of all the genes in the vicinity of a studyLocus. This is not weighted by the causal probability."""

    fill_na_value = 500_000
    feature_dependency_type = VariantIndex
-    feature_name = "distanceFootprintMinimumNeighbourhood"
+    feature_name = "distanceSentinelFootprintNeighbourhood"

    @classmethod
    def compute(
-        cls: type[DistanceFootprintMinimumNeighbourhoodFeature],
+        cls: type[DistanceSentinelFootprintNeighbourhoodFeature],
        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
        feature_dependency: dict[str, Any],
-    ) -> DistanceFootprintMinimumNeighbourhoodFeature:
+    ) -> DistanceSentinelFootprintNeighbourhoodFeature:
        """Computes the feature.
Args: @@ -410,17 +403,15 @@ def compute( feature_dependency (dict[str, Any]): Dataset that contains the distance information Returns: - DistanceFootprintMinimumNeighbourhoodFeature: Feature dataset + DistanceSentinelFootprintNeighbourhoodFeature: Feature dataset """ - agg_expr = f.min("weightedDistance") distance_type = "distanceFromFootprint" return cls( _df=convert_from_wide_to_long( - _common_neighbourhood_distance_feature_logic( + common_neighbourhood_distance_feature_logic( study_loci_to_annotate, feature_name=cls.feature_name, distance_type=distance_type, - agg_expr=agg_expr, **feature_dependency, ), id_vars=("studyLocusId", "geneId"), diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index 52d7ff76a..41084277f 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -25,12 +25,12 @@ from gentropy.dataset.l2g_features.distance import ( DistanceFootprintMeanFeature, DistanceFootprintMeanNeighbourhoodFeature, - DistanceFootprintMinimumFeature, - DistanceFootprintMinimumNeighbourhoodFeature, + DistanceSentinelFootprintFeature, + DistanceSentinelFootprintNeighbourhoodFeature, + DistanceSentinelTssFeature, + DistanceSentinelTssNeighbourhoodFeature, DistanceTssMeanFeature, DistanceTssMeanNeighbourhoodFeature, - DistanceTssMinimumFeature, - DistanceTssMinimumNeighbourhoodFeature, ) from gentropy.dataset.l2g_features.l2g_feature import L2GFeature from gentropy.dataset.l2g_gold_standard import L2GGoldStandard @@ -93,13 +93,13 @@ class FeatureFactory: """Factory class for creating features.""" feature_mapper: Mapping[str, type[L2GFeature]] = { - "distanceTssMinimum": DistanceTssMinimumFeature, + "distanceSentinelTss": DistanceSentinelTssFeature, + "distanceSentinelTssNeighbourhood": DistanceSentinelTssNeighbourhoodFeature, + "distanceSentinelFootprint": DistanceSentinelFootprintFeature, + "distanceSentinelFootprintNeighbourhood": DistanceSentinelFootprintNeighbourhoodFeature, "distanceTssMean": DistanceTssMeanFeature, "distanceTssMeanNeighbourhood": DistanceTssMeanNeighbourhoodFeature, - "distanceTssMinimumNeighbourhood": DistanceTssMinimumNeighbourhoodFeature, - "distanceFootprintMinimum": DistanceFootprintMinimumFeature, "distanceFootprintMean": DistanceFootprintMeanFeature, - "distanceFootprintMinimumNeighbourhood": DistanceFootprintMinimumNeighbourhoodFeature, "distanceFootprintMeanNeighbourhood": DistanceFootprintMeanNeighbourhoodFeature, "eQtlColocClppMaximum": EQtlColocClppMaximumFeature, "eQtlColocClppMaximumNeighbourhood": EQtlColocClppMaximumNeighbourhoodFeature, diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py index 79a54ef61..3aa3a95a4 100644 --- a/tests/gentropy/conftest.py +++ b/tests/gentropy/conftest.py @@ -607,7 +607,7 @@ def mock_l2g_feature_matrix(spark: SparkSession) -> L2GFeatureMatrix: (1, "gene1", 100.0, None), (2, "gene2", 1000.0, 0.0), ], - "studyLocusId LONG, geneId STRING, distanceTssMean FLOAT, distanceTssMinimum FLOAT", + "studyLocusId LONG, geneId STRING, distanceTssMean FLOAT, distanceSentinelTssMinimum FLOAT", ), with_gold_standard=False, ) diff --git a/tests/gentropy/dataset/test_l2g.py b/tests/gentropy/dataset/test_l2g.py index 2523b97dd..7d90c1767 100644 --- a/tests/gentropy/dataset/test_l2g.py +++ b/tests/gentropy/dataset/test_l2g.py @@ -177,7 +177,7 @@ def test_calculate_feature_missingness_rate( spark: SparkSession, mock_l2g_feature_matrix: L2GFeatureMatrix ) -> None: """Test L2GFeatureMatrix.calculate_feature_missingness_rate.""" - 
expected_missingness = {"distanceTssMean": 0.0, "distanceTssMinimum": 1.0} + expected_missingness = {"distanceTssMean": 0.0, "distanceSentinelTssMinimum": 1.0} observed_missingness = mock_l2g_feature_matrix.calculate_feature_missingness_rate() assert isinstance(observed_missingness, dict) assert mock_l2g_feature_matrix.features_list is not None and len( diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index 44e715ed7..96ba5cee1 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -39,12 +39,12 @@ from gentropy.dataset.l2g_features.distance import ( DistanceFootprintMeanFeature, DistanceFootprintMeanNeighbourhoodFeature, - DistanceFootprintMinimumFeature, - DistanceFootprintMinimumNeighbourhoodFeature, + DistanceSentinelFootprintFeature, + DistanceSentinelFootprintNeighbourhoodFeature, + DistanceSentinelTssFeature, + DistanceSentinelTssNeighbourhoodFeature, DistanceTssMeanFeature, DistanceTssMeanNeighbourhoodFeature, - DistanceTssMinimumFeature, - DistanceTssMinimumNeighbourhoodFeature, common_distance_feature_logic, common_neighbourhood_distance_feature_logic, ) @@ -78,13 +78,13 @@ SQtlColocH4MaximumNeighbourhoodFeature, TuQtlColocH4MaximumNeighbourhoodFeature, DistanceTssMeanFeature, - DistanceTssMinimumFeature, - DistanceFootprintMeanFeature, - DistanceFootprintMinimumFeature, DistanceTssMeanNeighbourhoodFeature, - DistanceTssMinimumNeighbourhoodFeature, + DistanceFootprintMeanFeature, DistanceFootprintMeanNeighbourhoodFeature, - DistanceFootprintMinimumNeighbourhoodFeature, + DistanceSentinelTssFeature, + DistanceSentinelTssNeighbourhoodFeature, + DistanceSentinelFootprintFeature, + DistanceSentinelFootprintNeighbourhoodFeature, ], ) def test_feature_factory_return_type( @@ -214,6 +214,7 @@ def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> N "colocalisationMethod": "COLOC", "numberColocalisingVariants": 1, "h4": 0.81, + "rightStudyType": "eqtl", }, { "leftStudyLocusId": 1, @@ -222,6 +223,7 @@ def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> N "colocalisationMethod": "COLOC", "numberColocalisingVariants": 1, "h4": 0.50, + "rightStudyType": "eqtl", }, { "leftStudyLocusId": 1, @@ -230,6 +232,7 @@ def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> N "colocalisationMethod": "COLOC", "numberColocalisingVariants": 1, "h4": 0.90, + "rightStudyType": "eqtl", }, ], schema=Colocalisation.get_schema(), @@ -308,78 +311,77 @@ class TestCommonDistanceFeatureLogic: """Test the CommonDistanceFeatureLogic methods.""" @pytest.mark.parametrize( - ("feature_name", "expected_distance"), + ("feature_name", "expected_data"), [ - ("distanceTssMinimum", 2.5), - ("distanceTssMean", 3.75), + ( + "distanceSentinelTss", + [ + {"studyLocusId": 1, "geneId": "gene1", "distanceSentinelTss": 0.0}, + {"studyLocusId": 1, "geneId": "gene2", "distanceSentinelTss": 0.95}, + ], + ), + ( + "distanceTssMean", + [ + {"studyLocusId": 1, "geneId": "gene1", "distanceTssMean": 0.09}, + {"studyLocusId": 1, "geneId": "gene2", "distanceTssMean": 0.65}, + ], + ), ], ) def test_common_distance_feature_logic( self: TestCommonDistanceFeatureLogic, spark: SparkSession, feature_name: str, - expected_distance: int, + expected_data: dict[str, Any], ) -> None: - """Test the logic of the function that extracts the distance between the variants in a credible set and a gene.""" - agg_expr = ( - f.min(f.col("weightedDistance")) - if feature_name == 
"distanceTssMinimum" - else f.mean(f.col("weightedDistance")) + """Test the logic of the function that extracts features from distance. + + 2 tests: + - distanceSentinelTss: distance of the sentinel is 10, the max distance is 10. In log scale, the score is 0. + - distanceTssMean: avg distance of any variant in the credible set, weighted by its posterior. + """ + observed_df = ( + common_distance_feature_logic( + self.sample_study_locus, + variant_index=self.sample_variant_index, + feature_name=feature_name, + distance_type=self.distance_type, + genomic_window=10, + ) + .withColumn(feature_name, f.round(f.col(feature_name), 2)) + .orderBy(feature_name) ) - observed_df = common_distance_feature_logic( - self.sample_study_locus, - variant_index=self.sample_variant_index, - feature_name=feature_name, - distance_type=self.distance_type, - agg_expr=agg_expr, + expected_df = ( + spark.createDataFrame(expected_data) + .select("studyLocusId", "geneId", feature_name) + .orderBy(feature_name) ) - assert observed_df.first()[feature_name] == expected_distance + assert ( + observed_df.collect() == expected_df.collect() + ), f"Expected and observed dataframes are not equal for feature {feature_name}." def test_common_neighbourhood_colocalisation_feature_logic( self: TestCommonDistanceFeatureLogic, spark: SparkSession, ) -> None: - """Test the logic of the function that extracts the distance between the variants in a credible set and the nearby genes.""" - another_sample_variant_index = VariantIndex( - _df=spark.createDataFrame( - [ - ( - "lead1", - "chrom", - 1, - "A", - "T", - [ - {"distanceFromTss": 10, "targetId": "gene1"}, - {"distanceFromTss": 100, "targetId": "gene2"}, - ], - ), - ( - "tag1", - "chrom", - 1, - "A", - "T", - [ - {"distanceFromTss": 5, "targetId": "gene1"}, - ], - ), - ], - self.variant_index_schema, - ), - _schema=VariantIndex.get_schema(), + """Test the logic of the function that extracts the distance between the sentinel of a credible set and the nearby genes.""" + feature_name = "distanceSentinelTssNeighbourhood" + observed_df = ( + common_neighbourhood_distance_feature_logic( + self.sample_study_locus, + variant_index=self.sample_variant_index, + feature_name=feature_name, + distance_type=self.distance_type, + genomic_window=10, + ) + .withColumn(feature_name, f.round(f.col(feature_name), 2)) + .orderBy(f.col(feature_name).asc()) ) - observed_df = common_neighbourhood_distance_feature_logic( - self.sample_study_locus, - variant_index=another_sample_variant_index, - feature_name="distanceTssMinimum", - distance_type=self.distance_type, - agg_expr=f.min("weightedDistance"), - ).orderBy(f.col("distanceTssMinimum").asc()) expected_df = spark.createDataFrame( - ([1, "gene2", -47.5], [1, "gene1", 0.0]), - ["studyLocusId", "geneId", "distanceTssMinimum"], - ).orderBy(f.col("distanceTssMinimum").asc()) + ([1, "gene1", -0.48], [1, "gene2", 0.48]), + ["studyLocusId", "geneId", feature_name], + ).orderBy(feature_name) assert ( observed_df.collect() == expected_df.collect() ), "Output doesn't meet the expectation." 
@@ -401,7 +403,7 @@ def _setup(self: TestCommonDistanceFeatureLogic, spark: SparkSession) -> None: "posteriorProbability": 0.5, }, { - "variantId": "tag1", # this variant is closer to gene1 + "variantId": "tag1", "posteriorProbability": 0.5, }, ], @@ -444,6 +446,7 @@ def _setup(self: TestCommonDistanceFeatureLogic, spark: SparkSession) -> None: "T", [ {"distanceFromTss": 10, "targetId": "gene1"}, + {"distanceFromTss": 2, "targetId": "gene2"}, ], ), ( From f08e4324950284feb4595ebd755049eb409f71fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Fri, 27 Sep 2024 15:22:32 +0100 Subject: [PATCH 43/48] docs: update distance class names --- docs/python_api/datasets/l2g_features/distance.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/python_api/datasets/l2g_features/distance.md b/docs/python_api/datasets/l2g_features/distance.md index 56e9069b3..e426b2952 100644 --- a/docs/python_api/datasets/l2g_features/distance.md +++ b/docs/python_api/datasets/l2g_features/distance.md @@ -4,13 +4,13 @@ title: From distance ## List of features -::: gentropy.dataset.l2g_features.distance.DistanceSentinelTssMinimumFeature +::: gentropy.dataset.l2g_features.distance.DistanceSentinelTssFeature +::: gentropy.dataset.l2g_features.distance.DistanceSentinelTssNeighbourhoodFeature ::: gentropy.dataset.l2g_features.distance.DistanceTssMeanFeature -::: gentropy.dataset.l2g_features.distance.DistanceSentinelFootprintMinimumFeature -::: gentropy.dataset.l2g_features.distance.DistanceFootprintMeanFeature -::: gentropy.dataset.l2g_features.distance.DistanceTssFootprintMinimumNeighbourhoodFeature ::: gentropy.dataset.l2g_features.distance.DistanceTssMeanNeighbourhoodFeature -::: gentropy.dataset.l2g_features.distance.DistanceSentinelFootprintMinimumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.distance.DistanceSentinelFootprintFeature +::: gentropy.dataset.l2g_features.distance.DistanceSentinelFootprintNeighbourhoodFeature +::: gentropy.dataset.l2g_features.distance.DistanceFootprintMeanFeature ::: gentropy.dataset.l2g_features.distance.DistanceFootprintMeanNeighbourhoodFeature ## Common logic From 72159843e0558c841a462d2d2e0b02c5bd179a52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Mon, 30 Sep 2024 14:53:37 +0100 Subject: [PATCH 44/48] fix: add all variant index mandatory fields in tests --- tests/gentropy/dataset/test_l2g_feature.py | 20 ++++++++++++++++--- .../open_targets/test_l2g_gold_standard.py | 16 +++++++++++++-- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index 96ba5cee1..6ce676a72 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -8,6 +8,7 @@ import pytest from pyspark.sql.types import ( ArrayType, + BooleanType, IntegerType, LongType, StringType, @@ -428,6 +429,7 @@ def _setup(self: TestCommonDistanceFeatureLogic, spark: SparkSession) -> None: [ StructField("distanceFromTss", LongType(), True), StructField("targetId", StringType(), True), + StructField("isEnsemblCanonical", BooleanType(), True), ] ) ), @@ -445,8 +447,16 @@ def _setup(self: TestCommonDistanceFeatureLogic, spark: SparkSession) -> None: "A", "T", [ - {"distanceFromTss": 10, "targetId": "gene1"}, - {"distanceFromTss": 2, "targetId": "gene2"}, + { + "distanceFromTss": 10, + "targetId": "gene1", + "isEnsemblCanonical": True, + }, + { + "distanceFromTss": 2, + "targetId": "gene2", + "isEnsemblCanonical": 
True, + }, ], ), ( @@ -456,7 +466,11 @@ def _setup(self: TestCommonDistanceFeatureLogic, spark: SparkSession) -> None: "A", "T", [ - {"distanceFromTss": 5, "targetId": "gene1"}, + { + "distanceFromTss": 5, + "targetId": "gene1", + "isEnsemblCanonical": True, + }, ], ), ], diff --git a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py index 54953d7a5..aa36359ca 100644 --- a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py +++ b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py @@ -8,6 +8,7 @@ from pyspark.sql import DataFrame from pyspark.sql.types import ( ArrayType, + BooleanType, IntegerType, LongType, StringType, @@ -98,8 +99,16 @@ def _setup(self: TestExpandGoldStandardWithNegatives, spark: SparkSession) -> No "A", "T", [ - {"distanceFromTss": 5, "targetId": "gene1"}, - {"distanceFromTss": 10, "targetId": "gene3"}, + { + "distanceFromTss": 5, + "targetId": "gene1", + "isEnsemblCanonical": True, + }, + { + "distanceFromTss": 10, + "targetId": "gene3", + "isEnsemblCanonical": True, + }, ], ), ], @@ -117,6 +126,9 @@ def _setup(self: TestExpandGoldStandardWithNegatives, spark: SparkSession) -> No [ StructField("distanceFromTss", LongType(), True), StructField("targetId", StringType(), True), + StructField( + "isEnsemblCanonical", BooleanType(), True + ), ] ) ), From 72ea515fb3ea9cb07c448072be2449f4ced0dab3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Mon, 30 Sep 2024 14:54:31 +0100 Subject: [PATCH 45/48] fix(schema_validator): remove extra `[]` from parent prefix --- src/gentropy/common/schemas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gentropy/common/schemas.py b/src/gentropy/common/schemas.py index 624e3e0e1..b9770c79c 100644 --- a/src/gentropy/common/schemas.py +++ b/src/gentropy/common/schemas.py @@ -94,7 +94,7 @@ def compare_array_schemas( schema_issues = compare_struct_schemas( observed_schema.elementType, expected_schema.elementType, - f"{parent_field_name}[].", + f"{parent_field_name}.", schema_issues, ) From 087b89752bb726ee743658ff01d37291c31522ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Mon, 30 Sep 2024 17:58:57 +0100 Subject: [PATCH 46/48] fix: convert studylocusid to string in tests --- tests/gentropy/dataset/test_l2g_feature.py | 50 +++++++++++++--------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index 6ce676a72..18d8a4066 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -137,12 +137,12 @@ def test__common_colocalisation_feature_logic( expected_df = spark.createDataFrame( [ { - "studyLocusId": 1, + "studyLocusId": "1", "geneId": "gene1", "eQtlColocH4Maximum": 0.81, }, { - "studyLocusId": 1, + "studyLocusId": "1", "geneId": "gene2", "eQtlColocH4Maximum": 0.9, }, @@ -170,12 +170,12 @@ def test__common_neighbourhood_colocalisation_feature_logic( expected_df = spark.createDataFrame( [ { - "studyLocusId": 1, + "studyLocusId": "1", "geneId": "gene1", "eQtlColocH4MaximumNeighbourhood": 0.08999999999999997, }, { - "studyLocusId": 1, + "studyLocusId": "1", "geneId": "gene2", "eQtlColocH4MaximumNeighbourhood": 0.0, }, @@ -196,7 +196,7 @@ def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> N _df=spark.createDataFrame( [ { - "studyLocusId": 1, + "studyLocusId": "1", "variantId": "lead1", "studyId": 
"study1", # this is a GWAS "chromosome": "1", @@ -209,8 +209,8 @@ def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> N _df=spark.createDataFrame( [ { - "leftStudyLocusId": 1, - "rightStudyLocusId": 2, + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", "chromosome": "1", "colocalisationMethod": "COLOC", "numberColocalisingVariants": 1, @@ -218,8 +218,8 @@ def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> N "rightStudyType": "eqtl", }, { - "leftStudyLocusId": 1, - "rightStudyLocusId": 3, # qtl linked to the same gene as studyLocusId 2 with a lower score + "leftStudyLocusId": "1", + "rightStudyLocusId": "3", # qtl linked to the same gene as studyLocusId 2 with a lower score "chromosome": "1", "colocalisationMethod": "COLOC", "numberColocalisingVariants": 1, @@ -227,8 +227,8 @@ def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> N "rightStudyType": "eqtl", }, { - "leftStudyLocusId": 1, - "rightStudyLocusId": 4, # qtl linked to a diff gene and with the highest score + "leftStudyLocusId": "1", + "rightStudyLocusId": "4", # qtl linked to a diff gene and with the highest score "chromosome": "1", "colocalisationMethod": "COLOC", "numberColocalisingVariants": 1, @@ -244,25 +244,25 @@ def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> N _df=spark.createDataFrame( [ { - "studyLocusId": 1, + "studyLocusId": "1", "variantId": "lead1", "studyId": "study1", # this is a GWAS "chromosome": "1", }, { - "studyLocusId": 2, + "studyLocusId": "2", "variantId": "lead1", "studyId": "study2", # this is a QTL (same gee) "chromosome": "1", }, { - "studyLocusId": 3, + "studyLocusId": "3", "variantId": "lead1", "studyId": "study3", # this is another QTL (same gene) "chromosome": "1", }, { - "studyLocusId": 4, + "studyLocusId": "4", "variantId": "lead1", "studyId": "study4", # this is another QTL (diff gene) "chromosome": "1", @@ -317,15 +317,23 @@ class TestCommonDistanceFeatureLogic: ( "distanceSentinelTss", [ - {"studyLocusId": 1, "geneId": "gene1", "distanceSentinelTss": 0.0}, - {"studyLocusId": 1, "geneId": "gene2", "distanceSentinelTss": 0.95}, + { + "studyLocusId": "1", + "geneId": "gene1", + "distanceSentinelTss": 0.0, + }, + { + "studyLocusId": "1", + "geneId": "gene2", + "distanceSentinelTss": 0.95, + }, ], ), ( "distanceTssMean", [ - {"studyLocusId": 1, "geneId": "gene1", "distanceTssMean": 0.09}, - {"studyLocusId": 1, "geneId": "gene2", "distanceTssMean": 0.65}, + {"studyLocusId": "1", "geneId": "gene1", "distanceTssMean": 0.09}, + {"studyLocusId": "1", "geneId": "gene2", "distanceTssMean": 0.65}, ], ), ], @@ -380,7 +388,7 @@ def test_common_neighbourhood_colocalisation_feature_logic( .orderBy(f.col(feature_name).asc()) ) expected_df = spark.createDataFrame( - ([1, "gene1", -0.48], [1, "gene2", 0.48]), + (["1", "gene1", -0.48], ["1", "gene2", 0.48]), ["studyLocusId", "geneId", feature_name], ).orderBy(feature_name) assert ( @@ -395,7 +403,7 @@ def _setup(self: TestCommonDistanceFeatureLogic, spark: SparkSession) -> None: _df=spark.createDataFrame( [ { - "studyLocusId": 1, + "studyLocusId": "1", "variantId": "lead1", "studyId": "study1", "locus": [ From f52b6e29102d23b1691858c06ab89e6f4f1050d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Mon, 30 Sep 2024 17:59:34 +0100 Subject: [PATCH 47/48] revert: revert 72ea515fb3ea9cb07c448072be2449f4ced0dab3 (it was ok) --- src/gentropy/common/schemas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/gentropy/common/schemas.py b/src/gentropy/common/schemas.py index b9770c79c..624e3e0e1 100644 --- a/src/gentropy/common/schemas.py +++ b/src/gentropy/common/schemas.py @@ -94,7 +94,7 @@ def compare_array_schemas( schema_issues = compare_struct_schemas( observed_schema.elementType, expected_schema.elementType, - f"{parent_field_name}.", + f"{parent_field_name}[].", schema_issues, ) From fea5fa84301149bd0319246a992e35ed99d894f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Irene=20L=C3=B3pez?= Date: Mon, 30 Sep 2024 18:09:38 +0100 Subject: [PATCH 48/48] fix: adapt test --- tests/gentropy/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py index 2f4b9e1a2..a70c1a87d 100644 --- a/tests/gentropy/conftest.py +++ b/tests/gentropy/conftest.py @@ -607,7 +607,7 @@ def mock_l2g_feature_matrix(spark: SparkSession) -> L2GFeatureMatrix: ("1", "gene1", 100.0, None), ("2", "gene2", 1000.0, 0.0), ], - "studyLocusId STRING, geneId STRING, distanceTssMean FLOAT, distanceTssMinimum FLOAT", + "studyLocusId STRING, geneId STRING, distanceTssMean FLOAT, distanceSentinelTssMinimum FLOAT", ), with_gold_standard=False, )
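A minimal, self-contained sketch of the neighbourhood transformation introduced in this
series (illustrative only: the toy values mirror the test fixtures above, and the actual
entry point in the package is `common_neighbourhood_distance_feature_logic`, not this snippet):

    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as f

    spark = SparkSession.builder.getOrCreate()

    # Local (per-gene) sentinel distance scores for one credible set
    local = spark.createDataFrame(
        [("1", "gene1", 0.0), ("1", "gene2", 0.954)],
        ["studyLocusId", "geneId", "distanceSentinelTss"],
    )

    # The neighbourhood feature is the local score minus the regional mean,
    # computed over all genes that share a studyLocusId
    w = Window.partitionBy("studyLocusId")
    neighbourhood = local.withColumn(
        "distanceSentinelTssNeighbourhood",
        f.col("distanceSentinelTss") - f.mean("distanceSentinelTss").over(w),
    ).drop("distanceSentinelTss")

    neighbourhood.show()  # gene1 -> -0.477, gene2 -> +0.477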