diff --git a/docs/howto/command_line/run_step_in_cli.md b/docs/howto/command_line/run_step_in_cli.md index ac7d55ff9..b935d84fb 100644 --- a/docs/howto/command_line/run_step_in_cli.md +++ b/docs/howto/command_line/run_step_in_cli.md @@ -24,7 +24,6 @@ Available options: ukbiobank variant_annotation variant_index - variant_to_gene Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace. ``` diff --git a/docs/python_api/datasets/l2g_feature.md b/docs/python_api/datasets/l2g_feature.md deleted file mode 100644 index bdab67e7c..000000000 --- a/docs/python_api/datasets/l2g_feature.md +++ /dev/null @@ -1,29 +0,0 @@ ---- -title: L2G Feature ---- - -## Abstract Class - -::: gentropy.dataset.l2g_feature.L2GFeature - -## Feature Classes - -### Derived from colocalisation - -::: gentropy.dataset.l2g_feature.EQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_feature.PQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_feature.SQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_feature.TuQtlColocClppMaximumFeature -::: gentropy.dataset.l2g_feature.EQtlColocH4MaximumFeature -::: gentropy.dataset.l2g_feature.PQtlColocH4MaximumFeature -::: gentropy.dataset.l2g_feature.SQtlColocH4MaximumFeature -::: gentropy.dataset.l2g_feature.TuQtlColocH4MaximumFeature - -### Derived from distance - -::: gentropy.dataset.l2g_feature.DistanceTssMinimumFeature -::: gentropy.dataset.l2g_feature.DistanceTssMeanFeature - -## Schema - ---8<-- "assets/schemas/l2g_feature.md" diff --git a/docs/python_api/datasets/l2g_features/_l2g_feature.md b/docs/python_api/datasets/l2g_features/_l2g_feature.md new file mode 100644 index 000000000..b2f6f8187 --- /dev/null +++ b/docs/python_api/datasets/l2g_features/_l2g_feature.md @@ -0,0 +1,11 @@ +--- +title: L2G Feature +--- + +## Abstract Class + +::: gentropy.dataset.l2g_features.l2g_feature.L2GFeature + +## Schema + +--8<-- "assets/schemas/l2g_feature.md" diff --git a/docs/python_api/datasets/l2g_features/colocalisation.md b/docs/python_api/datasets/l2g_features/colocalisation.md new file mode 100644 index 000000000..a3928c4ab --- /dev/null +++ b/docs/python_api/datasets/l2g_features/colocalisation.md @@ -0,0 +1,27 @@ +--- +title: From colocalisation +--- + +## List of features + +::: gentropy.dataset.l2g_features.colocalisation.EQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.PQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.SQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.TuQtlColocClppMaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.EQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.PQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.SQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.TuQtlColocH4MaximumFeature +::: gentropy.dataset.l2g_features.colocalisation.EQtlColocClppMaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.PQtlColocClppMaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.SQtlColocClppMaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.TuQtlColocClppMaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.EQtlColocH4MaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.PQtlColocH4MaximumNeighbourhoodFeature +::: gentropy.dataset.l2g_features.colocalisation.SQtlColocH4MaximumNeighbourhoodFeature +::: 
gentropy.dataset.l2g_features.colocalisation.TuQtlColocH4MaximumNeighbourhoodFeature + +## Common logic + +::: gentropy.dataset.l2g_features.colocalisation.common_colocalisation_feature_logic +::: gentropy.dataset.l2g_features.colocalisation.common_neighbourhood_colocalisation_feature_logic diff --git a/docs/python_api/datasets/l2g_features/distance.md b/docs/python_api/datasets/l2g_features/distance.md new file mode 100644 index 000000000..e426b2952 --- /dev/null +++ b/docs/python_api/datasets/l2g_features/distance.md @@ -0,0 +1,19 @@ +--- +title: From distance +--- + +## List of features + +::: gentropy.dataset.l2g_features.distance.DistanceSentinelTssFeature +::: gentropy.dataset.l2g_features.distance.DistanceSentinelTssNeighbourhoodFeature +::: gentropy.dataset.l2g_features.distance.DistanceTssMeanFeature +::: gentropy.dataset.l2g_features.distance.DistanceTssMeanNeighbourhoodFeature +::: gentropy.dataset.l2g_features.distance.DistanceSentinelFootprintFeature +::: gentropy.dataset.l2g_features.distance.DistanceSentinelFootprintNeighbourhoodFeature +::: gentropy.dataset.l2g_features.distance.DistanceFootprintMeanFeature +::: gentropy.dataset.l2g_features.distance.DistanceFootprintMeanNeighbourhoodFeature + +## Common logic + +::: gentropy.dataset.l2g_features.distance.common_distance_feature_logic +::: gentropy.dataset.l2g_features.distance.common_neighbourhood_distance_feature_logic diff --git a/docs/python_api/datasets/variant_to_gene.md b/docs/python_api/datasets/variant_to_gene.md deleted file mode 100644 index 2af67df92..000000000 --- a/docs/python_api/datasets/variant_to_gene.md +++ /dev/null @@ -1,9 +0,0 @@ ---- -title: Variant-to-gene ---- - -::: gentropy.dataset.v2g.V2G - -## Schema - ---8<-- "assets/schemas/v2g.md" diff --git a/docs/python_api/methods/l2g/_l2g.md b/docs/python_api/methods/l2g/_l2g.md index fca3ba79d..bbd7dad66 100644 --- a/docs/python_api/methods/l2g/_l2g.md +++ b/docs/python_api/methods/l2g/_l2g.md @@ -9,13 +9,10 @@ The **“locus-to-gene” (L2G)** model derives features to prioritize likely ca - **Chromatin Interaction:** (e.g., promoter-capture Hi-C) - **Variant Pathogenicity:** (from VEP) -The L2G model is distinct from the variant-to-gene (V2G) pipeline in that it: - -- Uses a machine-learning model to learn the weights of each evidence source based on a gold standard of previously identified causal genes. -- Relies upon fine-mapping and colocalization data. - Some of the predictive features weight variant-to-gene (or genomic region-to-gene) evidence based on the posterior probability that the variant is causal, determined through fine-mapping of the GWAS association. +For a more detailed description of how each feature is computed, see [the L2G Feature documentation](../../datasets/l2g_features/_l2g_feature.md). + Details of the L2G model are provided in our Nature Genetics publication (ref - [Nature Genetics Publication](https://www.nature.com/articles/s41588-021-00945-5)): - **Title:** An open approach to systematically prioritize causal variants and genes at all published human GWAS trait-associated loci. 
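+
+As a quick, illustrative sketch of the feature interface (assuming `credible_set`, `coloc`, `studies` and `study_locus` are placeholder names for already-loaded `StudyLocus`, `Colocalisation`, `StudyIndex` and `StudyLocus` datasets):
+
+```python
+from gentropy.dataset.l2g_features.colocalisation import EQtlColocH4MaximumFeature
+
+# Each feature class exposes a `compute` classmethod that takes the study loci
+# to annotate plus a dictionary of dependencies, which is unpacked as keyword
+# arguments of the shared feature logic.
+feature = EQtlColocH4MaximumFeature.compute(
+    study_loci_to_annotate=credible_set,
+    feature_dependency={
+        "colocalisation": coloc,
+        "study_index": studies,
+        "study_locus": study_locus,
+    },
+)
+# Long format: studyLocusId, geneId, featureName, featureValue
+feature.df.show()
+```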
diff --git a/docs/python_api/steps/variant_to_gene_step.md b/docs/python_api/steps/variant_to_gene_step.md deleted file mode 100644 index db1c1fd20..000000000 --- a/docs/python_api/steps/variant_to_gene_step.md +++ /dev/null @@ -1,5 +0,0 @@ ---- -title: variant_to_gene ---- - -::: gentropy.variant_to_gene.V2GStep diff --git a/notebooks/Release_QC_metrics.ipynb b/notebooks/Release_QC_metrics.ipynb index aa0924711..4eb27015b 100644 --- a/notebooks/Release_QC_metrics.ipynb +++ b/notebooks/Release_QC_metrics.ipynb @@ -13,21 +13,17 @@ "1. Import necessary modules and set up the release path and version.\n", "2. Load and analyze the variant index data:\n", " - Count the number of unique variants.\n", - "3. Load and analyze the variant-to-gene (v2g) data:\n", - " - Count the number of unique variants and total variant-to-gene assignments.\n", - " - Count the number of v2g assignments where the score is > 0.8.\n", - " - Plot a histogram/density plot for the \"score\" column.\n", - "4. Load and analyze the study index data for different data sources (FinnGen, GWASCat, eQTLcat):\n", + "3. Load and analyze the study index data for different data sources (FinnGen, GWASCat, eQTLcat):\n", " - Count the number of unique studies for each data source.\n", - "5. Analyze the credible sets for each datasource (Finngen, gwascat, eqtlcat):\n", + "4. Analyze the credible sets for each datasource (Finngen, gwascat, eqtlcat):\n", " - Analyze the credible sets:\n", " - Count the number of unique credible sets and unique study IDs.\n", " - Plot a scatter plot of the credible set size vs. the top posterior probability.\n", " - Count the number of credible sets with a top SNP posterior probability > 0.9..\n", - "6. Analyze colocalization data:\n", + "5. Analyze colocalization data:\n", " - Count the total number of colocalizations and the number with clpp > 0.8.\n", " - Calculate the average number of overlaps per credible set.\n", - "7. Analyze locus-to-gene (L2G) predictions:\n", + "6. Analyze locus-to-gene (L2G) predictions:\n", " - Load the locus-to-gene predictions data.\n", " - How many Studylocus contains a \"good\" l2g prediction? (l2g_score > 0.5)\n", " - How does l2g perform based on different datasource inputs? (impossible to tell)\n", @@ -126,79 +122,6 @@ "#variant_index.filter(variant_index[\"alleleFrequencies.populationName\"] > 0.05).show(10, False)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "\n", - "#### 3. Load and analyze the variant-to-gene (v2g) data:\n", - " - Count the number of unique variants and total variant-to-gene assignments.\n", - " - Count the number of v2g assignments where the score is > 0.8." 
- ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Unique variants in v2g release: 5090991 , total variant to gene assignments: 105771851 , number of v2g assignments where score > 0.8: 23176515 ( 4.552 %)\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " \r" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Summary of v2g_score: Mean: 0.5909395615801637 L.quart: 0.29 Median: 0.62 U.quart: 0.94\n" - ] - } - ], - "source": [ - "#v2g_path='gs://genetics_etl_python_playground/releases/24.03/variant_to_gene'\n", - "v2g_path=f\"{release_path}/{release_ver}/variant_to_gene\"\n", - "v2g=session.spark.read.parquet(v2g_path, recursiveFileLookup=True)\n", - "\n", - "#How many variants?\n", - "sample_size_quartiles = v2g.stat.approxQuantile(\"score\", [0.25, 0.5, 0.75], 0.01)\n", - "#v2g.select().toPandas().plot.hist()\n", - "#v2g.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - " - Plot a histogram/density plot for the \"score\" column." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "#The histogram/density plot for “score”\n", - "# Out of mem error:\n", - "#v2g.select(f.col(\"score\")).toPandas().plot.hist(bins=10, alpha=0.5, label=\"v2g scores\")" - ] - }, { "cell_type": "markdown", "metadata": {}, diff --git a/src/gentropy/assets/schemas/v2g.json b/src/gentropy/assets/schemas/v2g.json deleted file mode 100644 index afbe401dd..000000000 --- a/src/gentropy/assets/schemas/v2g.json +++ /dev/null @@ -1,77 +0,0 @@ -{ - "type": "struct", - "fields": [ - { - "name": "geneId", - "type": "string", - "nullable": false, - "metadata": {} - }, - { - "name": "variantId", - "type": "string", - "nullable": false, - "metadata": {} - }, - { - "name": "distance", - "type": "long", - "nullable": true, - "metadata": {} - }, - { - "name": "chromosome", - "type": "string", - "nullable": false, - "metadata": {} - }, - { - "name": "datatypeId", - "type": "string", - "nullable": false, - "metadata": {} - }, - { - "name": "datasourceId", - "type": "string", - "nullable": false, - "metadata": {} - }, - { - "name": "score", - "type": "double", - "nullable": true, - "metadata": {} - }, - { - "name": "resourceScore", - "type": "double", - "nullable": true, - "metadata": {} - }, - { - "name": "pmid", - "type": "string", - "nullable": true, - "metadata": {} - }, - { - "name": "biofeature", - "type": "string", - "nullable": true, - "metadata": {} - }, - { - "name": "variantFunctionalConsequenceId", - "type": "string", - "nullable": true, - "metadata": {} - }, - { - "name": "isHighQualityPlof", - "type": "boolean", - "nullable": true, - "metadata": {} - } - ] -} diff --git a/src/gentropy/common/spark_helpers.py b/src/gentropy/common/spark_helpers.py index 4d24212a2..8f60956e7 100644 --- a/src/gentropy/common/spark_helpers.py +++ b/src/gentropy/common/spark_helpers.py @@ -6,7 +6,7 @@ import sys from functools import reduce, wraps from itertools import chain -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, TypeVar import pyspark.sql.functions as f import pyspark.sql.types as t @@ -447,14 +447,14 @@ def order_array_of_structs_by_two_fields( ) -def map_column_by_dictionary(col: Column, mapping_dict: 
Dict[str, str]) -> Column: +def map_column_by_dictionary(col: Column, mapping_dict: dict[str, str]) -> Column: """Map column values to dictionary values by key. Missing consequence label will be converted to None, unmapped consequences will be mapped as None. Args: col (Column): Column containing labels to map. - mapping_dict (Dict[str, str]): Dictionary with mapping key/value pairs. + mapping_dict (dict[str, str]): Dictionary with mapping key/value pairs. Returns: Column: Column with mapped values. diff --git a/src/gentropy/config.py b/src/gentropy/config.py index 0a1f9438a..6f94cc9ed 100644 --- a/src/gentropy/config.py +++ b/src/gentropy/config.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass, field -from typing import Any, Dict, List +from typing import Any, List from hail import __file__ as hail_location from hydra.core.config_store import ConfigStore @@ -235,7 +235,7 @@ class LocusToGeneConfig(StepConfig): run_mode: str = MISSING predictions_path: str = MISSING credible_set_path: str = MISSING - variant_gene_path: str = MISSING + variant_index_path: str = MISSING colocalisation_path: str = MISSING study_index_path: str = MISSING model_path: str | None = None @@ -254,6 +254,16 @@ class LocusToGeneConfig(StepConfig): "pQtlColocH4Maximum", "sQtlColocH4Maximum", "tuQtlColocH4Maximum", + # distance to gene footprint + "distanceSentinelFootprint", + "distanceSentinelFootprintNeighbourhood", + "distanceFootprintMean", + "distanceFootprintMeanNeighbourhood", + # distance to gene tss + "distanceTssMean", + "distanceTssMeanNeighbourhood", + "distanceSentinelTss", + "distanceSentinelTssNeighbourhood", ] ) hyperparameters: dict[str, Any] = field( @@ -357,38 +367,6 @@ class ConvertToVcfStepConfig(StepConfig): _target_: str = "gentropy.variant_index.ConvertToVcfStep" -@dataclass -class VariantToGeneConfig(StepConfig): - """V2G step configuration.""" - - variant_index_path: str = MISSING - gene_index_path: str = MISSING - vep_consequences_path: str = MISSING - liftover_chain_file_path: str = MISSING - liftover_max_length_difference: int = 100 - max_distance: int = 500_000 - approved_biotypes: List[str] = field( - default_factory=lambda: [ - "protein_coding", - "3prime_overlapping_ncRNA", - "antisense", - "bidirectional_promoter_lncRNA", - "IG_C_gene", - "IG_D_gene", - "IG_J_gene", - "IG_V_gene", - "lincRNA", - "macro_lncRNA", - "non_coding", - "sense_intronic", - "sense_overlapping", - ] - ) - interval_sources: Dict[str, str] = field(default_factory=dict) - v2g_path: str = MISSING - _target_: str = "gentropy.variant_to_gene.V2GStep" - - @dataclass class LocusBreakerClumpingConfig(StepConfig): """Locus breaker clumping step configuration.""" @@ -565,7 +543,6 @@ def register_config() -> None: cs.store(group="step", name="ukb_ppp_eur_sumstat_preprocess", node=UkbPppEurConfig) cs.store(group="step", name="variant_index", node=VariantIndexConfig) cs.store(group="step", name="variant_to_vcf", node=ConvertToVcfStepConfig) - cs.store(group="step", name="variant_to_gene", node=VariantToGeneConfig) cs.store( group="step", name="window_based_clumping", node=WindowBasedClumpingStepConfig ) diff --git a/src/gentropy/dataset/colocalisation.py b/src/gentropy/dataset/colocalisation.py index 9e9035488..c85209462 100644 --- a/src/gentropy/dataset/colocalisation.py +++ b/src/gentropy/dataset/colocalisation.py @@ -18,6 +18,7 @@ from pyspark.sql.types import StructType from gentropy.dataset.study_index import StudyIndex + from gentropy.dataset.study_locus import StudyLocus from functools import reduce diff 
--git a/src/gentropy/dataset/intervals.py b/src/gentropy/dataset/intervals.py index c3b9136c9..37158810b 100644 --- a/src/gentropy/dataset/intervals.py +++ b/src/gentropy/dataset/intervals.py @@ -1,22 +1,19 @@ """Interval dataset.""" + from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING -import pyspark.sql.functions as f - from gentropy.common.Liftover import LiftOverSpark from gentropy.common.schemas import parse_spark_schema from gentropy.dataset.dataset import Dataset from gentropy.dataset.gene_index import GeneIndex -from gentropy.dataset.v2g import V2G if TYPE_CHECKING: from pyspark.sql import SparkSession from pyspark.sql.types import StructType - from gentropy.dataset.variant_index import VariantIndex @dataclass @@ -74,32 +71,3 @@ def from_source( source_class = source_to_class[source_name] data = source_class.read(spark, source_path) # type: ignore return source_class.parse(data, gene_index, lift) # type: ignore - - def v2g(self: Intervals, variant_index: VariantIndex) -> V2G: - """Convert intervals into V2G by intersecting with a variant index. - - Args: - variant_index (VariantIndex): Variant index dataset - - Returns: - V2G: Variant-to-gene evidence dataset - """ - return V2G( - _df=( - self.df.alias("interval") - .join( - variant_index.df.selectExpr( - "chromosome as vi_chromosome", "variantId", "position" - ).alias("vi"), - on=[ - f.col("vi.vi_chromosome") == f.col("interval.chromosome"), - f.col("vi.position").between( - f.col("interval.start"), f.col("interval.end") - ), - ], - how="inner", - ) - .drop("start", "end", "vi_chromosome", "position") - ), - _schema=V2G.get_schema(), - ) diff --git a/src/gentropy/dataset/l2g_feature.py b/src/gentropy/dataset/l2g_feature.py deleted file mode 100644 index 319570cfd..000000000 --- a/src/gentropy/dataset/l2g_feature.py +++ /dev/null @@ -1,506 +0,0 @@ -"""L2G Feature Dataset with a collection of methods that extract features from the gentropy datasets to be fed in L2G.""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any - -import pyspark.sql.functions as f - -from gentropy.common.schemas import parse_spark_schema -from gentropy.common.spark_helpers import convert_from_wide_to_long -from gentropy.dataset.colocalisation import Colocalisation -from gentropy.dataset.dataset import Dataset -from gentropy.dataset.l2g_gold_standard import L2GGoldStandard -from gentropy.dataset.study_index import StudyIndex -from gentropy.dataset.study_locus import StudyLocus -from gentropy.dataset.v2g import V2G - -if TYPE_CHECKING: - from pyspark.sql import DataFrame - from pyspark.sql.types import StructType - - -@dataclass -class L2GFeature(Dataset, ABC): - """Locus-to-gene feature dataset.""" - - def __post_init__( - self: L2GFeature, - feature_dependency_type: Any = None, - credible_set: StudyLocus | None = None, - ) -> None: - """Initializes a L2GFeature dataset. Any child class of L2GFeature must implement the `compute` method. - - Args: - feature_dependency_type (Any): The dependency that the L2GFeature dataset depends on. Defaults to None. - credible_set (StudyLocus | None): The credible set that the L2GFeature dataset is based on. Defaults to None. - """ - super().__post_init__() - self.feature_dependency_type = feature_dependency_type - self.credible_set = credible_set - - @classmethod - def get_schema(cls: type[L2GFeature]) -> StructType: - """Provides the schema for the L2GFeature dataset. 
- - Returns: - StructType: Schema for the L2GFeature dataset - """ - return parse_spark_schema("l2g_feature.json") - - @classmethod - @abstractmethod - def compute( - cls: type[L2GFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: Any, - ) -> L2GFeature: - """Computes the L2GFeature dataset. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (Any): The dependency that the L2GFeature class needs to compute the feature - Returns: - L2GFeature: a L2GFeature dataset - - Raises: - NotImplementedError: This method must be implemented in the child classes - """ - raise NotImplementedError("Must be implemented in the child classes") - - -def _common_colocalisation_feature_logic( - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - colocalisation_method: str, - colocalisation_metric: str, - feature_name: str, - qtl_type: str, - *, - colocalisation: Colocalisation, - study_index: StudyIndex, - study_locus: StudyLocus, -) -> DataFrame: - """Wrapper to call the logic that creates a type of colocalisation features. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - colocalisation_method (str): The colocalisation method to filter the data by - colocalisation_metric (str): The colocalisation metric to use - feature_name (str): The name of the feature to create - qtl_type (str): The type of QTL to filter the data by - colocalisation (Colocalisation): Dataset with the colocalisation results - study_index (StudyIndex): Study index to fetch study type and gene - study_locus (StudyLocus): Study locus to traverse between colocalisation and study index - - Returns: - DataFrame: Feature annotation in long format with the columns: studyLocusId, geneId, featureName, featureValue - """ - joining_cols = ( - ["studyLocusId", "geneId"] - if isinstance(study_loci_to_annotate, L2GGoldStandard) - else ["studyLocusId"] - ) - return convert_from_wide_to_long( - study_loci_to_annotate.df.join( - colocalisation.extract_maximum_coloc_probability_per_region_and_gene( - study_locus, - study_index, - filter_by_colocalisation_method=colocalisation_method, - filter_by_qtl=qtl_type, - ), - on=joining_cols, - ) - .selectExpr( - "studyLocusId", - "geneId", - f"{colocalisation_metric} as {feature_name}", - ) - .distinct(), - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ) - - -class EQtlColocClppMaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all eQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] - feature_name = "eQtlColocClppMaximum" - - @classmethod - def compute( - cls: type[EQtlColocClppMaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> EQtlColocClppMaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments. 
- - Returns: - EQtlColocClppMaximumFeature: Feature dataset - """ - colocalisation_method = "ECaviar" - colocalisation_metric = "clpp" - qtl_type = "eqtl" - - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class PQtlColocClppMaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] - feature_name = "pQtlColocClppMaximum" - - @classmethod - def compute( - cls: type[PQtlColocClppMaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> PQtlColocClppMaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - PQtlColocClppMaximumFeature: Feature dataset - """ - colocalisation_method = "ECaviar" - colocalisation_metric = "clpp" - qtl_type = "pqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class SQtlColocClppMaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all sQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] - feature_name = "sQtlColocClppMaximum" - - @classmethod - def compute( - cls: type[SQtlColocClppMaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> SQtlColocClppMaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - SQtlColocClppMaximumFeature: Feature dataset - """ - colocalisation_method = "ECaviar" - colocalisation_metric = "clpp" - qtl_type = "sqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class TuQtlColocClppMaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all tuQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] - feature_name = "tuQtlColocClppMaximum" - - @classmethod - def compute( - cls: type[TuQtlColocClppMaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> TuQtlColocClppMaximumFeature: - """Computes the feature. 
- - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - TuQtlColocClppMaximumFeature: Feature dataset - """ - colocalisation_method = "ECaviar" - colocalisation_metric = "clpp" - qtl_type = "tuqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class EQtlColocH4MaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all eQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] - feature_name = "eQtlColocH4Maximum" - - @classmethod - def compute( - cls: type[EQtlColocH4MaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> EQtlColocH4MaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - EQtlColocH4MaximumFeature: Feature dataset - """ - colocalisation_method = "Coloc" - colocalisation_metric = "h4" - qtl_type = "eqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class PQtlColocH4MaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] - feature_name = "pQtlColocH4Maximum" - - @classmethod - def compute( - cls: type[PQtlColocH4MaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> PQtlColocH4MaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - PQtlColocH4MaximumFeature: Feature dataset - """ - colocalisation_method = "Coloc" - colocalisation_metric = "h4" - qtl_type = "pqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class SQtlColocH4MaximumFeature(L2GFeature): - """Max CLPP for each (study, locus, gene) aggregating over all sQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] - feature_name = "sQtlColocH4Maximum" - - @classmethod - def compute( - cls: type[SQtlColocH4MaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> SQtlColocH4MaximumFeature: - """Computes the feature. 
- - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - SQtlColocH4MaximumFeature: Feature dataset - """ - colocalisation_method = "Coloc" - colocalisation_metric = "h4" - qtl_type = "sqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class TuQtlColocH4MaximumFeature(L2GFeature): - """Max H4 for each (study, locus, gene) aggregating over all tuQTLs.""" - - feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] - feature_name = "tuQtlColocH4Maximum" - - @classmethod - def compute( - cls: type[TuQtlColocH4MaximumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: dict[str, Any], - ) -> TuQtlColocH4MaximumFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (dict[str, Any]): Dataset with the colocalisation results - - Returns: - TuQtlColocH4MaximumFeature: Feature dataset - """ - colocalisation_method = "Coloc" - colocalisation_metric = "h4" - qtl_type = "tuqtl" - return cls( - _df=_common_colocalisation_feature_logic( - study_loci_to_annotate, - colocalisation_method, - colocalisation_metric, - cls.feature_name, - qtl_type, - **feature_dependency, - ), - _schema=cls.get_schema(), - ) - - -class DistanceTssMinimumFeature(L2GFeature): - """Minimum distance of all tagging variants to gene TSS.""" - - @classmethod - def compute( - cls: type[DistanceTssMinimumFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: V2G, - ) -> L2GFeature: - """Computes the feature. - - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (V2G): Dataset that contains the distance information - - Returns: - L2GFeature: Feature dataset - - Raises: - NotImplementedError: Not implemented - """ - raise NotImplementedError - - -class DistanceTssMeanFeature(L2GFeature): - """Average distance of all tagging variants to gene TSS. - - NOTE: to be rewritten taking variant index as input - """ - - fill_na_value = 500_000 - feature_dependency_type = V2G - - @classmethod - def compute( - cls: type[DistanceTssMeanFeature], - study_loci_to_annotate: StudyLocus | L2GGoldStandard, - feature_dependency: V2G, - ) -> DistanceTssMeanFeature: - """Computes the feature. 
- - Args: - study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation - feature_dependency (V2G): Dataset that contains the distance information - - Returns: - DistanceTssMeanFeature: Feature dataset - """ - agg_expr = f.mean("weightedScore").alias("distanceTssMean") - # Everything but expresion is common logic - v2g = feature_dependency.df.filter(f.col("datasourceId") == "canonical_tss") - wide_df = ( - study_loci_to_annotate.df.withColumn( - "variantInLocus", f.explode_outer("locus") - ) - .select( - "studyLocusId", - f.col("variantInLocus.variantId").alias("variantInLocusId"), - f.col("variantInLocus.posteriorProbability").alias( - "variantInLocusPosteriorProbability" - ), - ) - .join( - v2g.selectExpr("variantId as variantInLocusId", "geneId", "score"), - on="variantInLocusId", - how="inner", - ) - .withColumn( - "weightedScore", - f.col("score") * f.col("variantInLocusPosteriorProbability"), - ) - .groupBy("studyLocusId", "geneId") - .agg(agg_expr) - ) - return cls( - _df=convert_from_wide_to_long( - wide_df, - id_vars=("studyLocusId", "geneId"), - var_name="featureName", - value_name="featureValue", - ), - _schema=cls.get_schema(), - ) diff --git a/src/gentropy/dataset/l2g_features/__init__.py b/src/gentropy/dataset/l2g_features/__init__.py new file mode 100644 index 000000000..ce15cedfe --- /dev/null +++ b/src/gentropy/dataset/l2g_features/__init__.py @@ -0,0 +1,3 @@ +"""Feature factories for L2G.""" + +from __future__ import annotations diff --git a/src/gentropy/dataset/l2g_features/colocalisation.py b/src/gentropy/dataset/l2g_features/colocalisation.py new file mode 100644 index 000000000..c44573b72 --- /dev/null +++ b/src/gentropy/dataset/l2g_features/colocalisation.py @@ -0,0 +1,791 @@ +"""Collection of methods that extract features from the colocalisation datasets.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import pyspark.sql.functions as f +from pyspark.sql import Window + +from gentropy.common.spark_helpers import convert_from_wide_to_long +from gentropy.dataset.colocalisation import Colocalisation +from gentropy.dataset.l2g_features.l2g_feature import L2GFeature +from gentropy.dataset.l2g_gold_standard import L2GGoldStandard +from gentropy.dataset.study_index import StudyIndex +from gentropy.dataset.study_locus import StudyLocus + +if TYPE_CHECKING: + from pyspark.sql import DataFrame + + +def common_colocalisation_feature_logic( + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + colocalisation_method: str, + colocalisation_metric: str, + feature_name: str, + qtl_type: str, + *, + colocalisation: Colocalisation, + study_index: StudyIndex, + study_locus: StudyLocus, +) -> DataFrame: + """Wrapper to call the logic that creates a type of colocalisation features. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + colocalisation_method (str): The colocalisation method to filter the data by + colocalisation_metric (str): The colocalisation metric to use + feature_name (str): The name of the feature to create + qtl_type (str): The type of QTL to filter the data by + colocalisation (Colocalisation): Dataset with the colocalisation results + study_index (StudyIndex): Study index to fetch study type and gene + study_locus (StudyLocus): Study locus to traverse between colocalisation and study index + + Returns: + DataFrame: Feature annotation in long format with the columns: studyLocusId, geneId, featureName, featureValue + """ + joining_cols = ( + ["studyLocusId", "geneId"] + if isinstance(study_loci_to_annotate, L2GGoldStandard) + else ["studyLocusId"] + ) + return ( + study_loci_to_annotate.df.join( + colocalisation.extract_maximum_coloc_probability_per_region_and_gene( + study_locus, + study_index, + filter_by_colocalisation_method=colocalisation_method, + filter_by_qtl=qtl_type, + ), + on=joining_cols, + ) + .selectExpr( + "studyLocusId", + "geneId", + f"{colocalisation_metric} as {feature_name}", + ) + .distinct() + ) + + +def common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + colocalisation_method: str, + colocalisation_metric: str, + feature_name: str, + qtl_type: str, + *, + colocalisation: Colocalisation, + study_index: StudyIndex, + study_locus: StudyLocus, +) -> DataFrame: + """Wrapper to call the logic that creates a type of colocalisation features. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + colocalisation_method (str): The colocalisation method to filter the data by + colocalisation_metric (str): The colocalisation metric to use + feature_name (str): The name of the feature to create + qtl_type (str): The type of QTL to filter the data by + colocalisation (Colocalisation): Dataset with the colocalisation results + study_index (StudyIndex): Study index to fetch study type and gene + study_locus (StudyLocus): Study locus to traverse between colocalisation and study index + + Returns: + DataFrame: Feature annotation in long format with the columns: studyLocusId, geneId, featureName, featureValue + """ + # First maximum colocalisation score for each studylocus, gene + local_feature_name = feature_name.replace("Neighbourhood", "") + local_max = common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + local_feature_name, + qtl_type, + colocalisation=colocalisation, + study_index=study_index, + study_locus=study_locus, + ) + return ( + # Then compute maximum score in the vicinity (feature will be the same for any gene associated with a studyLocus) + local_max.withColumn( + "regional_maximum", + f.max(local_feature_name).over(Window.partitionBy("studyLocusId")), + ) + .withColumn(feature_name, f.col("regional_maximum") - f.col(local_feature_name)) + .drop("regional_maximum", local_feature_name) + ) + + +class EQtlColocClppMaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all eQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "eQtlColocClppMaximum" + + @classmethod + def compute( + cls: type[EQtlColocClppMaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + 
feature_dependency: dict[str, Any], + ) -> EQtlColocClppMaximumFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments. + + Returns: + EQtlColocClppMaximumFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "eqtl" + + return cls( + _df=convert_from_wide_to_long( + common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class EQtlColocClppMaximumNeighbourhoodFeature(L2GFeature): + """Max CLPP for each (study, locus) aggregating over all eQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "eQtlColocClppMaximumNeighbourhood" + + @classmethod + def compute( + cls: type[EQtlColocClppMaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> EQtlColocClppMaximumNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dictionary with the dependencies required. They are passed as keyword arguments. + + Returns: + EQtlColocClppMaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "eqtl" + + return cls( + _df=convert_from_wide_to_long( + common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class PQtlColocClppMaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "pQtlColocClppMaximum" + + @classmethod + def compute( + cls: type[PQtlColocClppMaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> PQtlColocClppMaximumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + PQtlColocClppMaximumFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "pqtl" + return cls( + _df=convert_from_wide_to_long( + common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class PQtlColocClppMaximumNeighbourhoodFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all pQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "pQtlColocClppMaximumNeighbourhood" + + @classmethod + def compute( + cls: type[PQtlColocClppMaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> PQtlColocClppMaximumNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + PQtlColocClppMaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "pqtl" + return cls( + _df=convert_from_wide_to_long( + common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class SQtlColocClppMaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all sQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "sQtlColocClppMaximum" + + @classmethod + def compute( + cls: type[SQtlColocClppMaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> SQtlColocClppMaximumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + SQtlColocClppMaximumFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "sqtl" + return cls( + _df=convert_from_wide_to_long( + common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class SQtlColocClppMaximumNeighbourhoodFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all sQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "sQtlColocClppMaximumNeighbourhood" + + @classmethod + def compute( + cls: type[SQtlColocClppMaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> SQtlColocClppMaximumNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + SQtlColocClppMaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "sqtl" + return cls( + _df=convert_from_wide_to_long( + common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class TuQtlColocClppMaximumFeature(L2GFeature): + """Max CLPP for each (study, locus, gene) aggregating over all tuQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "tuQtlColocClppMaximum" + + @classmethod + def compute( + cls: type[TuQtlColocClppMaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> TuQtlColocClppMaximumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + TuQtlColocClppMaximumFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "tuqtl" + return cls( + _df=convert_from_wide_to_long( + common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class TuQtlColocClppMaximumNeighbourhoodFeature(L2GFeature): + """Max CLPP for each (study, locus) aggregating over all tuQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "tuQtlColocClppMaximumNeighbourhood" + + @classmethod + def compute( + cls: type[TuQtlColocClppMaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> TuQtlColocClppMaximumNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + TuQtlColocClppMaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "ECaviar" + colocalisation_metric = "clpp" + qtl_type = "tuqtl" + return cls( + _df=convert_from_wide_to_long( + common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class EQtlColocH4MaximumFeature(L2GFeature): + """Max H4 for each (study, locus, gene) aggregating over all eQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "eQtlColocH4Maximum" + + @classmethod + def compute( + cls: type[EQtlColocH4MaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> EQtlColocH4MaximumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + EQtlColocH4MaximumFeature: Feature dataset + """ + colocalisation_method = "Coloc" + colocalisation_metric = "h4" + qtl_type = "eqtl" + return cls( + _df=convert_from_wide_to_long( + common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class EQtlColocH4MaximumNeighbourhoodFeature(L2GFeature): + """Max H4 for each (study, locus) aggregating over all eQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "eQtlColocH4MaximumNeighbourhood" + + @classmethod + def compute( + cls: type[EQtlColocH4MaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> EQtlColocH4MaximumNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + EQtlColocH4MaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "Coloc" + colocalisation_metric = "h4" + qtl_type = "eqtl" + return cls( + _df=convert_from_wide_to_long( + common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class PQtlColocH4MaximumFeature(L2GFeature): + """Max H4 for each (study, locus, gene) aggregating over all pQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "pQtlColocH4Maximum" + + @classmethod + def compute( + cls: type[PQtlColocH4MaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> PQtlColocH4MaximumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + PQtlColocH4MaximumFeature: Feature dataset + """ + colocalisation_method = "Coloc" + colocalisation_metric = "h4" + qtl_type = "pqtl" + return cls( + _df=convert_from_wide_to_long( + common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class PQtlColocH4MaximumNeighbourhoodFeature(L2GFeature): + """Max H4 for each (study, locus) aggregating over all pQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "pQtlColocH4MaximumNeighbourhood" + + @classmethod + def compute( + cls: type[PQtlColocH4MaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> PQtlColocH4MaximumNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + PQtlColocH4MaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "Coloc" + colocalisation_metric = "h4" + qtl_type = "pqtl" + return cls( + _df=convert_from_wide_to_long( + common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class SQtlColocH4MaximumFeature(L2GFeature): + """Max H4 for each (study, locus, gene) aggregating over all sQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "sQtlColocH4Maximum" + + @classmethod + def compute( + cls: type[SQtlColocH4MaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> SQtlColocH4MaximumFeature: + """Computes the feature. 
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + SQtlColocH4MaximumFeature: Feature dataset + """ + colocalisation_method = "Coloc" + colocalisation_metric = "h4" + qtl_type = "sqtl" + return cls( + _df=convert_from_wide_to_long( + common_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class SQtlColocH4MaximumNeighbourhoodFeature(L2GFeature): + """Max H4 for each (study, locus) aggregating over all sQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "sQtlColocH4MaximumNeighbourhood" + + @classmethod + def compute( + cls: type[SQtlColocH4MaximumNeighbourhoodFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> SQtlColocH4MaximumNeighbourhoodFeature: + """Computes the feature. + + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (dict[str, Any]): Dataset with the colocalisation results + + Returns: + SQtlColocH4MaximumNeighbourhoodFeature: Feature dataset + """ + colocalisation_method = "Coloc" + colocalisation_metric = "h4" + qtl_type = "sqtl" + return cls( + _df=convert_from_wide_to_long( + common_neighbourhood_colocalisation_feature_logic( + study_loci_to_annotate, + colocalisation_method, + colocalisation_metric, + cls.feature_name, + qtl_type, + **feature_dependency, + ), + id_vars=("studyLocusId", "geneId"), + var_name="featureName", + value_name="featureValue", + ), + _schema=cls.get_schema(), + ) + + +class TuQtlColocH4MaximumFeature(L2GFeature): + """Max H4 for each (study, locus, gene) aggregating over all tuQTLs.""" + + feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus] + feature_name = "tuQtlColocH4Maximum" + + @classmethod + def compute( + cls: type[TuQtlColocH4MaximumFeature], + study_loci_to_annotate: StudyLocus | L2GGoldStandard, + feature_dependency: dict[str, Any], + ) -> TuQtlColocH4MaximumFeature: + """Computes the feature. 
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset with the colocalisation results
+
+        Returns:
+            TuQtlColocH4MaximumFeature: Feature dataset
+        """
+        colocalisation_method = "Coloc"
+        colocalisation_metric = "h4"
+        qtl_type = "tuqtl"
+        return cls(
+            _df=convert_from_wide_to_long(
+                common_colocalisation_feature_logic(
+                    study_loci_to_annotate,
+                    colocalisation_method,
+                    colocalisation_metric,
+                    cls.feature_name,
+                    qtl_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class TuQtlColocH4MaximumNeighbourhoodFeature(L2GFeature):
+    """Max H4 for each (study, locus) aggregating over all tuQTLs."""
+
+    feature_dependency_type = [Colocalisation, StudyIndex, StudyLocus]
+    feature_name = "tuQtlColocH4MaximumNeighbourhood"
+
+    @classmethod
+    def compute(
+        cls: type[TuQtlColocH4MaximumNeighbourhoodFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> TuQtlColocH4MaximumNeighbourhoodFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset with the colocalisation results
+
+        Returns:
+            TuQtlColocH4MaximumNeighbourhoodFeature: Feature dataset
+        """
+        colocalisation_method = "Coloc"
+        colocalisation_metric = "h4"
+        qtl_type = "tuqtl"
+        return cls(
+            _df=convert_from_wide_to_long(
+                common_neighbourhood_colocalisation_feature_logic(
+                    study_loci_to_annotate,
+                    colocalisation_method,
+                    colocalisation_metric,
+                    cls.feature_name,
+                    qtl_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
diff --git a/src/gentropy/dataset/l2g_features/distance.py b/src/gentropy/dataset/l2g_features/distance.py
new file mode 100644
index 000000000..ea030108c
--- /dev/null
+++ b/src/gentropy/dataset/l2g_features/distance.py
@@ -0,0 +1,422 @@
+"""Collection of methods that extract distance features from the variant index dataset."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import pyspark.sql.functions as f
+from pyspark.sql import Window
+
+from gentropy.common.spark_helpers import convert_from_wide_to_long
+from gentropy.dataset.l2g_features.l2g_feature import L2GFeature
+from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
+from gentropy.dataset.study_locus import StudyLocus
+from gentropy.dataset.variant_index import VariantIndex
+
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame
+
+
+def common_distance_feature_logic(
+    study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+    *,
+    variant_index: VariantIndex,
+    feature_name: str,
+    distance_type: str,
+    genomic_window: int = 500_000,
+) -> DataFrame:
+    """Calculate the distance feature that correlates a variant in a credible set with a gene.
+
+    The distance is weighted by the posterior probability of the variant to factor in its contribution to the trait when we look at the average distance score for all variants in the credible set.
diff --git a/src/gentropy/dataset/l2g_features/distance.py b/src/gentropy/dataset/l2g_features/distance.py
new file mode 100644
index 000000000..ea030108c
--- /dev/null
+++ b/src/gentropy/dataset/l2g_features/distance.py
@@ -0,0 +1,422 @@
+"""Collection of methods that extract distance features from the variant index dataset."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+import pyspark.sql.functions as f
+from pyspark.sql import Window
+
+from gentropy.common.spark_helpers import convert_from_wide_to_long
+from gentropy.dataset.l2g_features.l2g_feature import L2GFeature
+from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
+from gentropy.dataset.study_locus import StudyLocus
+from gentropy.dataset.variant_index import VariantIndex
+
+if TYPE_CHECKING:
+    from pyspark.sql import DataFrame
+
+
+def common_distance_feature_logic(
+    study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+    *,
+    variant_index: VariantIndex,
+    feature_name: str,
+    distance_type: str,
+    genomic_window: int = 500_000,
+) -> DataFrame:
+    """Calculate the distance feature that correlates a variant in a credible set with a gene.
+
+    The distance is weighted by the posterior probability of the variant to factor in its contribution to the trait when we look at the average distance score for all variants in the credible set.
+
+    Args:
+        study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+        variant_index (VariantIndex): The dataset containing distance to gene information
+        feature_name (str): The name of the feature
+        distance_type (str): The type of distance to gene
+        genomic_window (int): The maximum window size to consider
+
+    Returns:
+        DataFrame: Feature dataset
+    """
+    distances_dataset = variant_index.get_distance_to_gene(distance_type=distance_type)
+    if "Mean" in feature_name:
+        # Weighting by the SNP contribution is only applied when we are averaging all distances
+        distance_score_expr = (
+            f.lit(genomic_window) - f.col(distance_type) + f.lit(1)
+        ) * f.col("posteriorProbability")
+        agg_expr = f.mean(f.col("distance_score"))
+        df = study_loci_to_annotate.df.withColumn(
+            "variantInLocus", f.explode_outer("locus")
+        ).select(
+            "studyLocusId",
+            f.col("variantInLocus.variantId").alias("variantId"),
+            f.col("variantInLocus.posteriorProbability").alias("posteriorProbability"),
+        )
+    elif "Sentinel" in feature_name:
+        # For sentinel features we calculate the unweighted distance between the sentinel (lead) variant and the gene
+        distance_score_expr = f.lit(genomic_window) - f.col(distance_type) + f.lit(1)
+        agg_expr = f.first(f.col("distance_score"))
+        df = study_loci_to_annotate.df.select("studyLocusId", "variantId")
+    return (
+        df.join(
+            distances_dataset.withColumnRenamed("targetId", "geneId"),
+            on="variantId",
+            how="inner",
+        )
+        .withColumn("distance_score", f.log10(distance_score_expr))
+        .groupBy("studyLocusId", "geneId")
+        .agg(agg_expr.alias(feature_name))
+    )
+
+
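To get a feel for the scoring above: the raw distance is folded into `log10(genomic_window - distance + 1)`, so a gene right at the variant scores highest and a gene at the edge of the window scores near zero; for the `Mean` features the argument is first multiplied by the variant's posterior probability. A back-of-the-envelope check in plain Python:

```python
from math import log10

genomic_window = 500_000

def distance_score(distance: int, posterior_probability: float = 1.0) -> float:
    """log10-scaled, posterior-weighted distance score, mirroring distance_score_expr above."""
    return log10((genomic_window - distance + 1) * posterior_probability)

print(distance_score(0))            # ~5.70: gene TSS right at the variant
print(distance_score(499_999))      # ~0.30: gene at the far edge of the window
print(distance_score(10_000, 0.5))  # same distance, discounted by a 0.5 posterior
```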
+def common_neighbourhood_distance_feature_logic(
+    study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+    *,
+    variant_index: VariantIndex,
+    feature_name: str,
+    distance_type: str,
+    genomic_window: int = 500_000,
+) -> DataFrame:
+    """Calculate the distance feature that correlates any variant in a credible set with any gene nearby the locus.
+
+    The distance is weighted by the posterior probability of the variant to factor in its contribution to the trait.
+
+    Args:
+        study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+        variant_index (VariantIndex): The dataset containing distance to gene information
+        feature_name (str): The name of the feature
+        distance_type (str): The type of distance to gene
+        genomic_window (int): The maximum window size to consider
+
+    Returns:
+        DataFrame: Feature dataset
+    """
+    local_feature_name = feature_name.replace("Neighbourhood", "")
+    # First compute the local distance metric for each gene
+    local_metric = common_distance_feature_logic(
+        study_loci_to_annotate,
+        feature_name=local_feature_name,
+        distance_type=distance_type,
+        variant_index=variant_index,
+        genomic_window=genomic_window,
+    )
+    return (
+        # Then compute the mean metric in the vicinity (the regional term is the same for any gene associated with a studyLocus)
+        local_metric.withColumn(
+            "regional_metric",
+            f.mean(f.col(local_feature_name)).over(Window.partitionBy("studyLocusId")),
+        )
+        .withColumn(feature_name, f.col(local_feature_name) - f.col("regional_metric"))
+        .drop("regional_metric", local_feature_name)
+    )
+
+
+class DistanceTssMeanFeature(L2GFeature):
+    """Average distance of all tagging variants to gene TSS."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceTssMean"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceTssMeanFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceTssMeanFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceTssMeanFeature: Feature dataset
+        """
+        distance_type = "distanceFromTss"
+        return cls(
+            _df=convert_from_wide_to_long(
+                common_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class DistanceTssMeanNeighbourhoodFeature(L2GFeature):
+    """Mean distance to TSS, relative to the mean distance for all genes in the vicinity of a studyLocus."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceTssMeanNeighbourhood"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceTssMeanNeighbourhoodFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceTssMeanNeighbourhoodFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceTssMeanNeighbourhoodFeature: Feature dataset
+        """
+        distance_type = "distanceFromTss"
+        return cls(
+            _df=convert_from_wide_to_long(
+                common_neighbourhood_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
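The neighbourhood transform above is simply "local score minus the locus-wide mean", which makes the feature signed: positive for genes that score better than their neighbours, negative otherwise. A self-contained sketch with toy values and the real column names:

```python
import pyspark.sql.functions as f
from pyspark.sql import SparkSession, Window

spark = SparkSession.builder.getOrCreate()

local = spark.createDataFrame(
    [("sl1", "geneA", 0.0), ("sl1", "geneB", 0.95)],
    ["studyLocusId", "geneId", "distanceSentinelTss"],
)

# Subtract the mean over all genes sharing the studyLocus, as the common logic does
w = Window.partitionBy("studyLocusId")
local.withColumn(
    "distanceSentinelTssNeighbourhood",
    f.col("distanceSentinelTss") - f.mean("distanceSentinelTss").over(w),
).show()
# geneA lands below the locus mean (negative), geneB above it (positive)
```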
+class DistanceSentinelTssFeature(L2GFeature):
+    """Distance of the sentinel variant to gene TSS. This is not weighted by the causal probability."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceSentinelTss"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceSentinelTssFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceSentinelTssFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceSentinelTssFeature: Feature dataset
+        """
+        distance_type = "distanceFromTss"
+        return cls(
+            _df=convert_from_wide_to_long(
+                common_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class DistanceSentinelTssNeighbourhoodFeature(L2GFeature):
+    """Distance between the sentinel variant and a gene TSS, relative to the distance to all the genes in the vicinity of a studyLocus. This is not weighted by the causal probability."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceSentinelTssNeighbourhood"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceSentinelTssNeighbourhoodFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceSentinelTssNeighbourhoodFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceSentinelTssNeighbourhoodFeature: Feature dataset
+        """
+        distance_type = "distanceFromTss"
+        return cls(
+            _df=convert_from_wide_to_long(
+                common_neighbourhood_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
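The TSS features above read their distances from `VariantIndex.get_distance_to_gene` (redefined later in this diff), which explodes the nested `transcriptConsequences` array into one row per variant/gene pair. A runnable sketch of that explode on a toy schema:

```python
import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

variants = spark.createDataFrame(
    [
        (
            "lead1",
            [
                {"targetId": "geneA", "distanceFromTss": 10},
                {"targetId": "geneB", "distanceFromTss": 2},
            ],
        )
    ],
    "variantId STRING, transcriptConsequences ARRAY<STRUCT<targetId: STRING, distanceFromTss: LONG>>",
)

# One (variantId, targetId, distanceFromTss) row per annotated gene
variants.select(
    "variantId", f.explode("transcriptConsequences").alias("tc")
).select("variantId", "tc.targetId", "tc.distanceFromTss").show()
```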
+class DistanceFootprintMeanFeature(L2GFeature):
+    """Average distance of all tagging variants to the footprint of a gene."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceFootprintMean"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceFootprintMeanFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceFootprintMeanFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceFootprintMeanFeature: Feature dataset
+        """
+        distance_type = "distanceFromFootprint"
+        return cls(
+            _df=convert_from_wide_to_long(
+                common_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class DistanceFootprintMeanNeighbourhoodFeature(L2GFeature):
+    """Mean distance to footprint, relative to the mean distance for all genes in the vicinity of a studyLocus."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceFootprintMeanNeighbourhood"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceFootprintMeanNeighbourhoodFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceFootprintMeanNeighbourhoodFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceFootprintMeanNeighbourhoodFeature: Feature dataset
+        """
+        distance_type = "distanceFromFootprint"
+        return cls(
+            _df=convert_from_wide_to_long(
+                common_neighbourhood_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
+class DistanceSentinelFootprintFeature(L2GFeature):
+    """Distance between the sentinel variant and the footprint of a gene."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceSentinelFootprint"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceSentinelFootprintFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceSentinelFootprintFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceSentinelFootprintFeature: Feature dataset
+        """
+        distance_type = "distanceFromFootprint"
+        return cls(
+            _df=convert_from_wide_to_long(
+                common_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
+
+
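All of these classes are thin shells over the two common functions: pick a `feature_name`, a `distance_type`, a delegate, and melt the result. A hypothetical sketch of what adding a new distance feature could look like; `DistanceExonMeanFeature`, `distanceExonMean` and `distanceFromExon` are invented names, and a real feature would also need an entry in `FeatureFactory.feature_mapper`:

```python
from __future__ import annotations

from typing import Any

from gentropy.common.spark_helpers import convert_from_wide_to_long
from gentropy.dataset.l2g_features.distance import common_distance_feature_logic
from gentropy.dataset.l2g_features.l2g_feature import L2GFeature
from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.variant_index import VariantIndex


class DistanceExonMeanFeature(L2GFeature):
    """Hypothetical: average distance of all credible set variants to a gene's exons."""

    fill_na_value = 500_000
    feature_dependency_type = VariantIndex
    feature_name = "distanceExonMean"  # invented name; "Mean" triggers the weighted branch

    @classmethod
    def compute(
        cls: type[DistanceExonMeanFeature],
        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
        feature_dependency: dict[str, Any],
    ) -> DistanceExonMeanFeature:
        return cls(
            _df=convert_from_wide_to_long(
                common_distance_feature_logic(
                    study_loci_to_annotate,
                    feature_name=cls.feature_name,
                    distance_type="distanceFromExon",  # invented column name
                    **feature_dependency,
                ),
                id_vars=("studyLocusId", "geneId"),
                var_name="featureName",
                value_name="featureValue",
            ),
            _schema=cls.get_schema(),
        )
```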
+class DistanceSentinelFootprintNeighbourhoodFeature(L2GFeature):
+    """Distance between the sentinel variant and a gene footprint, relative to the distance to all the genes in the vicinity of a studyLocus. This is not weighted by the causal probability."""
+
+    fill_na_value = 500_000
+    feature_dependency_type = VariantIndex
+    feature_name = "distanceSentinelFootprintNeighbourhood"
+
+    @classmethod
+    def compute(
+        cls: type[DistanceSentinelFootprintNeighbourhoodFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: dict[str, Any],
+    ) -> DistanceSentinelFootprintNeighbourhoodFeature:
+        """Computes the feature.
+
+        Args:
+            study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation
+            feature_dependency (dict[str, Any]): Dataset that contains the distance information
+
+        Returns:
+            DistanceSentinelFootprintNeighbourhoodFeature: Feature dataset
+        """
+        distance_type = "distanceFromFootprint"
+        return cls(
+            _df=convert_from_wide_to_long(
+                common_neighbourhood_distance_feature_logic(
+                    study_loci_to_annotate,
+                    feature_name=cls.feature_name,
+                    distance_type=distance_type,
+                    **feature_dependency,
+                ),
+                id_vars=("studyLocusId", "geneId"),
+                var_name="featureName",
+                value_name="featureValue",
+            ),
+            _schema=cls.get_schema(),
+        )
diff --git a/src/gentropy/dataset/l2g_features/l2g_feature.py b/src/gentropy/dataset/l2g_features/l2g_feature.py
new file mode 100644
index 000000000..7073ca758
--- /dev/null
+++ b/src/gentropy/dataset/l2g_features/l2g_feature.py
@@ -0,0 +1,65 @@
+"""L2G Feature Dataset with a collection of methods that extract features from the gentropy datasets to be fed in L2G."""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
+
+from gentropy.common.schemas import parse_spark_schema
+from gentropy.dataset.dataset import Dataset
+
+if TYPE_CHECKING:
+    from pyspark.sql.types import StructType
+
+    from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
+    from gentropy.dataset.study_locus import StudyLocus
+
+
+@dataclass
+class L2GFeature(Dataset, ABC):
+    """Locus-to-gene feature dataset that serves as template to generate each of the features that inform about locus to gene assignments."""
+
+    def __post_init__(
+        self: L2GFeature,
+        feature_dependency_type: Any = None,
+        credible_set: StudyLocus | None = None,
+    ) -> None:
+        """Initializes a L2GFeature dataset. Any child class of L2GFeature must implement the `compute` method.
+
+        Args:
+            feature_dependency_type (Any): The dependency that the L2GFeature dataset depends on. Defaults to None.
+            credible_set (StudyLocus | None): The credible set that the L2GFeature dataset is based on. Defaults to None.
+        """
+        super().__post_init__()
+        self.feature_dependency_type = feature_dependency_type
+        self.credible_set = credible_set
+
+    @classmethod
+    def get_schema(cls: type[L2GFeature]) -> StructType:
+        """Provides the schema for the L2GFeature dataset.
+
+        Returns:
+            StructType: Schema for the L2GFeature dataset
+        """
+        return parse_spark_schema("l2g_feature.json")
+
+    @classmethod
+    @abstractmethod
+    def compute(
+        cls: type[L2GFeature],
+        study_loci_to_annotate: StudyLocus | L2GGoldStandard,
+        feature_dependency: Any,
+    ) -> L2GFeature:
+        """Computes the L2GFeature dataset.
+ + Args: + study_loci_to_annotate (StudyLocus | L2GGoldStandard): The dataset containing study loci that will be used for annotation + feature_dependency (Any): The dependency that the L2GFeature class needs to compute the feature + Returns: + L2GFeature: a L2GFeature dataset + + Raises: + NotImplementedError: This method must be implemented in the child classes + """ + raise NotImplementedError("Must be implemented in the child classes") diff --git a/src/gentropy/dataset/l2g_gold_standard.py b/src/gentropy/dataset/l2g_gold_standard.py index 89f4c5f5d..064f6cc0e 100644 --- a/src/gentropy/dataset/l2g_gold_standard.py +++ b/src/gentropy/dataset/l2g_gold_standard.py @@ -18,7 +18,7 @@ from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.study_locus_overlap import StudyLocusOverlap - from gentropy.dataset.v2g import V2G + from gentropy.dataset.variant_index import VariantIndex @dataclass @@ -33,16 +33,16 @@ class L2GGoldStandard(Dataset): def from_otg_curation( cls: type[L2GGoldStandard], gold_standard_curation: DataFrame, - v2g: V2G, study_locus_overlap: StudyLocusOverlap, + variant_index: VariantIndex, interactions: DataFrame, ) -> L2GGoldStandard: """Initialise L2GGoldStandard from source dataset. Args: gold_standard_curation (DataFrame): Gold standard curation dataframe, extracted from - v2g (V2G): Variant to gene dataset to bring distance between a variant and a gene's TSS study_locus_overlap (StudyLocusOverlap): Study locus overlap dataset to remove duplicated loci + variant_index (VariantIndex): Dataset to bring distance between a variant and a gene's footprint interactions (DataFrame): Gene-gene interactions dataset to remove negative cases where the gene interacts with a positive gene Returns: @@ -55,7 +55,9 @@ def from_otg_curation( interactions_df = cls.process_gene_interactions(interactions) return ( - OpenTargetsL2GGoldStandard.as_l2g_gold_standard(gold_standard_curation, v2g) + OpenTargetsL2GGoldStandard.as_l2g_gold_standard( + gold_standard_curation, variant_index + ) # .filter_unique_associations(study_locus_overlap) .remove_false_negatives(interactions_df) ) diff --git a/src/gentropy/dataset/v2g.py b/src/gentropy/dataset/v2g.py deleted file mode 100644 index 04bad2113..000000000 --- a/src/gentropy/dataset/v2g.py +++ /dev/null @@ -1,51 +0,0 @@ -"""V2G dataset.""" -from __future__ import annotations - -from dataclasses import dataclass -from typing import TYPE_CHECKING - -import pyspark.sql.functions as f - -from gentropy.common.schemas import parse_spark_schema -from gentropy.dataset.dataset import Dataset - -if TYPE_CHECKING: - from pyspark.sql.types import StructType - - from gentropy.dataset.gene_index import GeneIndex - - -@dataclass -class V2G(Dataset): - """Variant-to-gene (V2G) evidence dataset. - - A variant-to-gene (V2G) evidence is understood as any piece of evidence that supports the association of a variant with a likely causal gene. The evidence can sometimes be context-specific and refer to specific `biofeatures` (e.g. cell types) - """ - - @classmethod - def get_schema(cls: type[V2G]) -> StructType: - """Provides the schema for the V2G dataset. - - Returns: - StructType: Schema for the V2G dataset - """ - return parse_spark_schema("v2g.json") - - def filter_by_genes(self: V2G, genes: GeneIndex) -> V2G: - """Filter V2G dataset by genes. 
- - Args: - genes (GeneIndex): Gene index dataset to filter by - - Returns: - V2G: V2G dataset filtered by genes - """ - self.df = self._df.join(genes.df.select("geneId"), on="geneId", how="inner") - return self - - def extract_distance_tss_minimum(self: V2G) -> None: - """Extract minimum distance to TSS.""" - self.df = self._df.filter(f.col("distance")).withColumn( - "distanceTssMinimum", - f.expr("min(distTss) OVER (PARTITION BY studyLocusId)"), - ) diff --git a/src/gentropy/dataset/variant_index.py b/src/gentropy/dataset/variant_index.py index 2f24cd985..4d53d741a 100644 --- a/src/gentropy/dataset/variant_index.py +++ b/src/gentropy/dataset/variant_index.py @@ -12,18 +12,15 @@ from gentropy.common.spark_helpers import ( get_nested_struct_schema, get_record_with_maximum_value, - normalise_column, rename_all_columns, safe_array_union, ) from gentropy.dataset.dataset import Dataset -from gentropy.dataset.v2g import V2G if TYPE_CHECKING: from pyspark.sql import Column, DataFrame from pyspark.sql.types import StructType - from gentropy.dataset.gene_index import GeneIndex @dataclass @@ -231,165 +228,106 @@ def filter_by_variant(self: VariantIndex, df: DataFrame) -> VariantIndex: _schema=self.schema, ) - def get_transcript_consequence_df( - self: VariantIndex, gene_index: GeneIndex | None = None - ) -> DataFrame: - """Dataframe of exploded transcript consequences. - - Optionally the trancript consequences can be reduced to the universe of a gene index. - - Args: - gene_index (GeneIndex | None): A gene index. Defaults to None. - - Returns: - DataFrame: A dataframe exploded by transcript consequences with the columns variantId, chromosome, transcriptConsequence - """ - # exploding the array removes records without VEP annotation - transript_consequences = self.df.withColumn( - "transcriptConsequence", f.explode("transcriptConsequences") - ).select( - "variantId", - "chromosome", - "position", - "transcriptConsequence", - f.col("transcriptConsequence.targetId").alias("geneId"), - ) - if gene_index: - transript_consequences = transript_consequences.join( - f.broadcast(gene_index.df), - on=["chromosome", "geneId"], - ) - return transript_consequences - - def get_distance_to_tss( + def get_distance_to_gene( self: VariantIndex, - gene_index: GeneIndex, + *, + distance_type: str = "distanceFromTss", max_distance: int = 500_000, - ) -> V2G: - """Extracts variant to gene assignments for variants falling within a window of a gene's TSS. + ) -> DataFrame: + """Extracts variant to gene assignments for variants falling within a window of a gene's TSS or footprint. Args: - gene_index (GeneIndex): A gene index to filter by. - max_distance (int): The maximum distance from the TSS to consider. Defaults to 500_000. + distance_type (str): The type of distance to use. Can be "distanceFromTss" or "distanceFromFootprint". Defaults to "distanceFromTss". + max_distance (int): The maximum distance to consider. Defaults to 500_000, the default window size for VEP. 
         Returns:
-            V2G: variant to gene assignments with their distance to the TSS
-        """
-        return V2G(
-            _df=(
-                self.df.alias("variant")
-                .join(
-                    f.broadcast(gene_index.locations_lut()).alias("gene"),
-                    on=[
-                        f.col("variant.chromosome") == f.col("gene.chromosome"),
-                        f.abs(f.col("variant.position") - f.col("gene.tss"))
-                        <= max_distance,
-                    ],
-                    how="inner",
-                )
-                .withColumn(
-                    "distance", f.abs(f.col("variant.position") - f.col("gene.tss"))
-                )
-                .withColumn(
-                    "inverse_distance",
-                    max_distance - f.col("distance"),
-                )
-                .transform(lambda df: normalise_column(df, "inverse_distance", "score"))
-                .select(
-                    "variantId",
-                    f.col("variant.chromosome").alias("chromosome"),
-                    "distance",
-                    "geneId",
-                    "score",
-                    f.lit("distance").alias("datatypeId"),
-                    f.lit("canonical_tss").alias("datasourceId"),
-                )
-            ),
-            _schema=V2G.get_schema(),
-        )
+            DataFrame: A dataframe with the distance between a variant and a gene's TSS or footprint.

-    def get_plof_v2g(self: VariantIndex, gene_index: GeneIndex) -> V2G:
-        """Creates a dataset with variant to gene assignments with a flag indicating if the variant is predicted to be a loss-of-function variant by the LOFTEE algorithm.
+        Raises:
+            ValueError: Invalid distance type.
+        """
+        if distance_type not in {"distanceFromTss", "distanceFromFootprint"}:
+            raise ValueError(
+                f"Invalid distance_type: {distance_type}. Must be 'distanceFromTss' or 'distanceFromFootprint'."
+            )
+        df = self.df.select(
+            "variantId", f.explode("transcriptConsequences").alias("tc")
+        ).select("variantId", "tc.targetId", f"tc.{distance_type}")
+        if max_distance == 500_000:
+            return df
+        elif max_distance < 500_000:
+            return df.filter(f"{distance_type} <= {max_distance}")
+        else:
+            raise ValueError(
+                f"max_distance must be less than or equal to 500_000. Got {max_distance}."
+            )

-        Optionally the trancript consequences can be reduced to the universe of a gene index.
+    def get_loftee(self: VariantIndex) -> DataFrame:
+        """Returns a dataframe with a flag indicating whether a variant is predicted to cause loss of function in a gene. The source of this information is the LOFTEE algorithm (https://github.com/konradjk/loftee).

-        Args:
-            gene_index (GeneIndex): A gene index to filter by.
+        !!! note "This will return a filtered dataframe with only variants that have been annotated by LOFTEE."
Returns: - V2G: variant to gene assignments from the LOFTEE algorithm + DataFrame: variant to gene assignments from the LOFTEE algorithm """ - return V2G( - _df=( - self.get_transcript_consequence_df(gene_index) - .filter(f.col("transcriptConsequence.lofteePrediction").isNotNull()) - .withColumn( - "isHighQualityPlof", - f.when( - f.col("transcriptConsequence.lofteePrediction") == "HC", True - ).when( - f.col("transcriptConsequence.lofteePrediction") == "LC", False - ), - ) - .withColumn( - "score", - f.when(f.col("isHighQualityPlof"), 1.0).when( - ~f.col("isHighQualityPlof"), 0 - ), - ) - .select( - "variantId", - "chromosome", - "geneId", - "isHighQualityPlof", - f.col("score"), - f.lit("vep").alias("datatypeId"), - f.lit("loftee").alias("datasourceId"), - ) - ), - _schema=V2G.get_schema(), + return ( + self.df.select("variantId", f.explode("transcriptConsequences").alias("tc")) + .filter(f.col("tc.lofteePrediction").isNotNull()) + .withColumn( + "isHighQualityPlof", + f.when(f.col("tc.lofteePrediction") == "HC", True).when( + f.col("tc.lofteePrediction") == "LC", False + ), + ) + .select( + "variantId", + f.col("tc.targetId"), + f.col("tc.lofteePrediction"), + "isHighQualityPlof", + ) ) - def get_most_severe_transcript_consequence( + def get_most_severe_gene_consequence( self: VariantIndex, + *, vep_consequences: DataFrame, - gene_index: GeneIndex, - ) -> V2G: - """Creates a dataset with variant to gene assignments based on VEP's predicted consequence of the transcript. - - Optionally the trancript consequences can be reduced to the universe of a gene index. + ) -> DataFrame: + """Returns a dataframe with the most severe consequence for a variant/gene pair. Args: vep_consequences (DataFrame): A dataframe of VEP consequences - gene_index (GeneIndex): A gene index to filter by. Defaults to None. 
         Returns:
-            V2G: High and medium severity variant to gene assignments
+            DataFrame: A dataframe with the most severe consequence (plus a severity score) for a variant/gene pair
         """
-        return V2G(
-            _df=self.get_transcript_consequence_df(gene_index)
+        return (
+            self.df.select("variantId", f.explode("transcriptConsequences").alias("tc"))
             .select(
                 "variantId",
-                "chromosome",
-                f.col("transcriptConsequence.targetId").alias("geneId"),
-                f.explode(
-                    "transcriptConsequence.variantFunctionalConsequenceIds"
-                ).alias("variantFunctionalConsequenceId"),
-                f.lit("vep").alias("datatypeId"),
-                f.lit("variantConsequence").alias("datasourceId"),
+                f.col("tc.targetId"),
+                f.explode(f.col("tc.variantFunctionalConsequenceIds")).alias(
+                    "variantFunctionalConsequenceId"
+                ),
             )
             .join(
-                f.broadcast(vep_consequences),
+                # TODO: make this table a project config
+                f.broadcast(
+                    vep_consequences.selectExpr(
+                        "variantFunctionalConsequenceId", "score as severityScore"
+                    )
+                ),
                 on="variantFunctionalConsequenceId",
                 how="inner",
             )
-            .drop("label")
-            .filter(f.col("score") != 0)
-            # A variant can have multiple predicted consequences on a transcript, the most severe one is selected
+            .filter(f.col("severityScore").isNotNull())
             .transform(
+                # A variant can have multiple predicted consequences on a transcript, the most severe one is selected
                 lambda df: get_record_with_maximum_value(
-                    df, ["variantId", "geneId"], "score"
+                    df, ["variantId", "targetId"], "severityScore"
                 )
-            ),
-            _schema=V2G.get_schema(),
+            )
+            .withColumnRenamed(
+                "variantFunctionalConsequenceId",
+                "mostSevereVariantFunctionalConsequenceId",
+            )
         )
diff --git a/src/gentropy/datasource/ensembl/vep_parser.py b/src/gentropy/datasource/ensembl/vep_parser.py
index d84b58407..e3e36140d 100644
--- a/src/gentropy/datasource/ensembl/vep_parser.py
+++ b/src/gentropy/datasource/ensembl/vep_parser.py
@@ -3,7 +3,7 @@
 from __future__ import annotations

 import importlib.resources as pkg_resources
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING

 import pandas as pd
 from pyspark.sql import SparkSession
@@ -529,14 +529,14 @@ def _collect_uniprot_accessions(trembl: Column, swissprot: Column) -> Column:
         )

     @staticmethod
-    def _parse_variant_location_id(vep_input_field: Column) -> List[Column]:
+    def _parse_variant_location_id(vep_input_field: Column) -> list[Column]:
         r"""Parse variant identifier, chromosome, position, reference allele and alternate allele from VEP input field.

         Args:
             vep_input_field (Column): Column containing variant vcf string used as VEP input.

         Returns:
-            List[Column]: List of columns containing chromosome, position, reference allele and alternate allele.
+            list[Column]: List of columns containing chromosome, position, reference allele and alternate allele.
""" variant_fields = f.split(vep_input_field, r"\t") return [ diff --git a/src/gentropy/datasource/open_targets/l2g_gold_standard.py b/src/gentropy/datasource/open_targets/l2g_gold_standard.py index 26d5a0253..21edcc201 100644 --- a/src/gentropy/datasource/open_targets/l2g_gold_standard.py +++ b/src/gentropy/datasource/open_targets/l2g_gold_standard.py @@ -9,7 +9,7 @@ from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_locus import StudyLocus -from gentropy.dataset.v2g import V2G +from gentropy.dataset.variant_index import VariantIndex class OpenTargetsL2GGoldStandard: @@ -60,7 +60,9 @@ def parse_positive_curation( @classmethod def expand_gold_standard_with_negatives( - cls: Type[OpenTargetsL2GGoldStandard], positive_set: DataFrame, v2g: V2G + cls: Type[OpenTargetsL2GGoldStandard], + positive_set: DataFrame, + variant_index: VariantIndex, ) -> DataFrame: """Create full set of positive and negative evidence of locus to gene associations. @@ -68,7 +70,7 @@ def expand_gold_standard_with_negatives( Args: positive_set (DataFrame): Positive set from curation - v2g (V2G): Variant to gene dataset to bring distance between a variant and a gene's TSS + variant_index (VariantIndex): Variant index to get distance to gene Returns: DataFrame: Full set of positive and negative evidence of locus to gene associations @@ -76,9 +78,13 @@ def expand_gold_standard_with_negatives( return ( positive_set.withColumnRenamed("geneId", "curated_geneId") .join( - v2g.df.selectExpr( - "variantId", "geneId as non_curated_geneId", "distance" - ).filter(f.col("distance") <= cls.LOCUS_TO_GENE_WINDOW), + variant_index.get_distance_to_gene() + .selectExpr( + "variantId", + "targetId as non_curated_geneId", + "distanceFromTss", + ) + .filter(f.col("distanceFromTss") <= cls.LOCUS_TO_GENE_WINDOW), on="variantId", how="left", ) @@ -86,7 +92,7 @@ def expand_gold_standard_with_negatives( "goldStandardSet", f.when( (f.col("curated_geneId") == f.col("non_curated_geneId")) - # to keep the positives that are outside the v2g dataset + # to keep the positives that are not part of the variant index | (f.col("non_curated_geneId").isNull()), f.lit(L2GGoldStandard.GS_POSITIVE_LABEL), ).otherwise(L2GGoldStandard.GS_NEGATIVE_LABEL), @@ -98,27 +104,27 @@ def expand_gold_standard_with_negatives( f.col("curated_geneId"), ).otherwise(f.col("non_curated_geneId")), ) - .drop("distance", "curated_geneId", "non_curated_geneId") + .drop("distanceFromTss", "curated_geneId", "non_curated_geneId") ) @classmethod def as_l2g_gold_standard( cls: type[OpenTargetsL2GGoldStandard], gold_standard_curation: DataFrame, - v2g: V2G, + variant_index: VariantIndex, ) -> L2GGoldStandard: """Initialise L2GGoldStandard from source dataset. Args: gold_standard_curation (DataFrame): Gold standard curation dataframe, extracted from https://github.com/opentargets/genetics-gold-standards - v2g (V2G): Variant to gene dataset to bring distance between a variant and a gene's TSS + variant_index (VariantIndex): Dataset to bring distance between a variant and a gene's footprint Returns: L2GGoldStandard: L2G Gold Standard dataset. False negatives have not yet been removed. 
""" return L2GGoldStandard( _df=cls.parse_positive_curation(gold_standard_curation).transform( - cls.expand_gold_standard_with_negatives, v2g + cls.expand_gold_standard_with_negatives, variant_index ), _schema=L2GGoldStandard.get_schema(), ) diff --git a/src/gentropy/l2g.py b/src/gentropy/l2g.py index ff8c6c8ff..9b9b7aa90 100644 --- a/src/gentropy/l2g.py +++ b/src/gentropy/l2g.py @@ -17,7 +17,7 @@ from gentropy.dataset.l2g_prediction import L2GPrediction from gentropy.dataset.study_index import StudyIndex from gentropy.dataset.study_locus import StudyLocus -from gentropy.dataset.v2g import V2G +from gentropy.dataset.variant_index import VariantIndex from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader from gentropy.method.l2g.model import LocusToGeneModel from gentropy.method.l2g.trainer import LocusToGeneTrainer @@ -38,7 +38,7 @@ def __init__( model_path: str | None = None, credible_set_path: str, gold_standard_curation_path: str | None = None, - variant_gene_path: str | None = None, + variant_index_path: str | None = None, colocalisation_path: str | None = None, study_index_path: str | None = None, gene_interactions_path: str | None = None, @@ -59,7 +59,7 @@ def __init__( model_path (str | None): Path to the model. It can be either in the filesystem or the name on the Hugging Face Hub (in the form of username/repo_name). credible_set_path (str): Path to the credible set dataset necessary to build the feature matrix gold_standard_curation_path (str | None): Path to the gold standard curation file - variant_gene_path (str | None): Path to the variant-gene dataset + variant_index_path (str | None): Path to the variant index dataset colocalisation_path (str | None): Path to the colocalisation dataset study_index_path (str | None): Path to the study index dataset gene_interactions_path (str | None): Path to the gene interactions dataset @@ -96,8 +96,10 @@ def __init__( if study_index_path else None ) - self.v2g = ( - V2G.from_parquet(session, variant_gene_path) if variant_gene_path else None + self.variant_index = ( + VariantIndex.from_parquet(session, variant_index_path) + if variant_index_path + else None ) self.coloc = ( Colocalisation.from_parquet( @@ -107,7 +109,7 @@ def __init__( else None ) self.features_input_loader = L2GFeatureInputLoader( - v2g=self.v2g, + variant_index=self.variant_index, coloc=self.coloc, studies=self.studies, study_locus=self.credible_set, @@ -134,7 +136,7 @@ def run_predict(self) -> None: Raises: ValueError: If not all dependencies in prediction mode are set """ - if self.studies and self.v2g and self.coloc: + if self.studies and self.coloc: predictions = L2GPrediction.from_credible_set( self.session, self.credible_set, @@ -157,9 +159,9 @@ def run_train(self) -> None: if ( self.gs_curation and self.interactions - and self.v2g and self.wandb_run_name and self.model_path + and self.variant_index ): wandb_key = access_gcp_secret("wandb-key", "open-targets-genetics-dev") # Process gold standard and L2G features @@ -203,25 +205,30 @@ def _generate_feature_matrix(self, write_feature_matrix: bool) -> L2GFeatureMatr Raises: ValueError: If write_feature_matrix is set to True but a path is not provided or if dependencies to build features are not set. 
""" - if self.gs_curation and self.interactions and self.v2g: + if ( + self.gs_curation + and self.interactions + and self.studies + and self.variant_index + ): study_locus_overlap = StudyLocus( _df=self.credible_set.df.join( f.broadcast( - self.gs_curation - .withColumn( + self.gs_curation.withColumn( "variantId", f.concat_ws( - "_", - f.col("sentinel_variant.locus_GRCh38.chromosome"), - f.col("sentinel_variant.locus_GRCh38.position"), - f.col("sentinel_variant.alleles.reference"), - f.col("sentinel_variant.alleles.alternative"), - ) - ) - .select( + "_", + f.col("sentinel_variant.locus_GRCh38.chromosome"), + f.col("sentinel_variant.locus_GRCh38.position"), + f.col("sentinel_variant.alleles.reference"), + f.col("sentinel_variant.alleles.alternative"), + ), + ).select( StudyLocus.assign_study_locus_id( - ["association_info.otg_id", # studyId - "variantId"] + [ + "association_info.otg_id", # studyId + "variantId", + ] ), ) ), @@ -233,7 +240,7 @@ def _generate_feature_matrix(self, write_feature_matrix: bool) -> L2GFeatureMatr gold_standards = L2GGoldStandard.from_otg_curation( gold_standard_curation=self.gs_curation, - v2g=self.v2g, + variant_index=self.variant_index, study_locus_overlap=study_locus_overlap, interactions=self.interactions, ) diff --git a/src/gentropy/method/l2g/feature_factory.py b/src/gentropy/method/l2g/feature_factory.py index c0f0ef9b4..41084277f 100644 --- a/src/gentropy/method/l2g/feature_factory.py +++ b/src/gentropy/method/l2g/feature_factory.py @@ -4,17 +4,35 @@ from typing import Any, Iterator, Mapping -from gentropy.dataset.l2g_feature import ( +from gentropy.dataset.l2g_features.colocalisation import ( EQtlColocClppMaximumFeature, + EQtlColocClppMaximumNeighbourhoodFeature, EQtlColocH4MaximumFeature, - L2GFeature, + EQtlColocH4MaximumNeighbourhoodFeature, PQtlColocClppMaximumFeature, + PQtlColocClppMaximumNeighbourhoodFeature, PQtlColocH4MaximumFeature, + PQtlColocH4MaximumNeighbourhoodFeature, SQtlColocClppMaximumFeature, + SQtlColocClppMaximumNeighbourhoodFeature, SQtlColocH4MaximumFeature, + SQtlColocH4MaximumNeighbourhoodFeature, TuQtlColocClppMaximumFeature, + TuQtlColocClppMaximumNeighbourhoodFeature, TuQtlColocH4MaximumFeature, + TuQtlColocH4MaximumNeighbourhoodFeature, ) +from gentropy.dataset.l2g_features.distance import ( + DistanceFootprintMeanFeature, + DistanceFootprintMeanNeighbourhoodFeature, + DistanceSentinelFootprintFeature, + DistanceSentinelFootprintNeighbourhoodFeature, + DistanceSentinelTssFeature, + DistanceSentinelTssNeighbourhoodFeature, + DistanceTssMeanFeature, + DistanceTssMeanNeighbourhoodFeature, +) +from gentropy.dataset.l2g_features.l2g_feature import L2GFeature from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_locus import StudyLocus @@ -75,16 +93,30 @@ class FeatureFactory: """Factory class for creating features.""" feature_mapper: Mapping[str, type[L2GFeature]] = { - # "distanceTssMinimum": DistanceTssMinimumFeature, - # "distanceTssMean": DistanceTssMeanFeature, + "distanceSentinelTss": DistanceSentinelTssFeature, + "distanceSentinelTssNeighbourhood": DistanceSentinelTssNeighbourhoodFeature, + "distanceSentinelFootprint": DistanceSentinelFootprintFeature, + "distanceSentinelFootprintNeighbourhood": DistanceSentinelFootprintNeighbourhoodFeature, + "distanceTssMean": DistanceTssMeanFeature, + "distanceTssMeanNeighbourhood": DistanceTssMeanNeighbourhoodFeature, + "distanceFootprintMean": DistanceFootprintMeanFeature, + "distanceFootprintMeanNeighbourhood": 
DistanceFootprintMeanNeighbourhoodFeature, "eQtlColocClppMaximum": EQtlColocClppMaximumFeature, + "eQtlColocClppMaximumNeighbourhood": EQtlColocClppMaximumNeighbourhoodFeature, "pQtlColocClppMaximum": PQtlColocClppMaximumFeature, + "pQtlColocClppMaximumNeighbourhood": PQtlColocClppMaximumNeighbourhoodFeature, "sQtlColocClppMaximum": SQtlColocClppMaximumFeature, + "sQtlColocClppMaximumNeighbourhood": SQtlColocClppMaximumNeighbourhoodFeature, "tuQtlColocClppMaximum": TuQtlColocClppMaximumFeature, + "tuQtlColocClppMaximumNeighbourhood": TuQtlColocClppMaximumNeighbourhoodFeature, "eQtlColocH4Maximum": EQtlColocH4MaximumFeature, + "eQtlColocH4MaximumNeighbourhood": EQtlColocH4MaximumNeighbourhoodFeature, "pQtlColocH4Maximum": PQtlColocH4MaximumFeature, + "pQtlColocH4MaximumNeighbourhood": PQtlColocH4MaximumNeighbourhoodFeature, "sQtlColocH4Maximum": SQtlColocH4MaximumFeature, + "sQtlColocH4MaximumNeighbourhood": SQtlColocH4MaximumNeighbourhoodFeature, "tuQtlColocH4Maximum": TuQtlColocH4MaximumFeature, + "tuQtlColocH4MaximumNeighbourhood": TuQtlColocH4MaximumNeighbourhoodFeature, } def __init__( diff --git a/src/gentropy/variant_to_gene.py b/src/gentropy/variant_to_gene.py deleted file mode 100644 index cf21053d7..000000000 --- a/src/gentropy/variant_to_gene.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Step to generate variant annotation dataset.""" - -from __future__ import annotations - -from functools import reduce - -from pyspark.sql import functions as f - -from gentropy.common.Liftover import LiftOverSpark -from gentropy.common.session import Session -from gentropy.dataset.gene_index import GeneIndex -from gentropy.dataset.intervals import Intervals -from gentropy.dataset.v2g import V2G -from gentropy.dataset.variant_index import VariantIndex - - -class V2GStep: - """Variant-to-gene (V2G) step. - - This step aims to generate a dataset that contains multiple pieces of evidence supporting the functional association of specific variants with genes. Some of the evidence types include: - - 1. Chromatin interaction experiments, e.g. Promoter Capture Hi-C (PCHi-C). - 2. In silico functional predictions, e.g. Variant Effect Predictor (VEP) from Ensembl. - 3. Distance between the variant and each gene's canonical transcription start site (TSS). - - Attributes: - session (Session): Session object. - variant_index_path (str): Input variant index path. - gene_index_path (str): Input gene index path. - vep_consequences_path (str): Input VEP consequences path. - liftover_chain_file_path (str): Path to GRCh37 to GRCh38 chain file. - liftover_max_length_difference: Maximum length difference for liftover. - max_distance (int): Maximum distance to consider. - approved_biotypes (list[str]): List of approved biotypes. - intervals (dict): Dictionary of interval sources. - v2g_path (str): Output V2G path. - """ - - def __init__( - self, - session: Session, - variant_index_path: str, - gene_index_path: str, - vep_consequences_path: str, - liftover_chain_file_path: str, - approved_biotypes: list[str], - interval_sources: dict[str, str], - v2g_path: str, - max_distance: int = 500_000, - liftover_max_length_difference: int = 100, - ) -> None: - """Run Variant-to-gene (V2G) step. - - Args: - session (Session): Session object. - variant_index_path (str): Input variant index path. - gene_index_path (str): Input gene index path. - vep_consequences_path (str): Input VEP consequences path. - liftover_chain_file_path (str): Path to GRCh37 to GRCh38 chain file. - approved_biotypes (list[str]): List of approved biotypes. 
- interval_sources (dict[str, str]): Dictionary of interval sources. - v2g_path (str): Output V2G path. - max_distance (int): Maximum distance to consider. - liftover_max_length_difference (int): Maximum length difference for liftover. - """ - # Read - gene_index = GeneIndex.from_parquet(session, gene_index_path) - vi = VariantIndex.from_parquet(session, variant_index_path).persist() - # Reading VEP consequence to score table and cast the score to the right type: - vep_consequences = session.spark.read.csv( - vep_consequences_path, sep="\t", header=True - ).withColumn("score", f.col("score").cast("double")) - - # Transform - lift = LiftOverSpark( - # lift over variants to hg38 - liftover_chain_file_path, - liftover_max_length_difference, - ) - gene_index_filtered = gene_index.filter_by_biotypes( - # Filter gene index by approved biotypes to define V2G gene universe - list(approved_biotypes) - ) - - intervals = Intervals( - _df=reduce( - lambda x, y: x.unionByName(y, allowMissingColumns=True), - # create interval instances by parsing each source - [ - Intervals.from_source( - session.spark, source_name, source_path, gene_index, lift - ).df - for source_name, source_path in interval_sources.items() - ], - ), - _schema=Intervals.get_schema(), - ) - v2g_datasets = [ - vi.get_distance_to_tss(gene_index_filtered, max_distance), - vi.get_most_severe_transcript_consequence( - vep_consequences, gene_index_filtered - ), - vi.get_plof_v2g(gene_index_filtered), - intervals.v2g(vi), - ] - v2g = V2G( - _df=reduce( - lambda x, y: x.unionByName(y, allowMissingColumns=True), - [dataset.df for dataset in v2g_datasets], - ).repartition("chromosome"), - _schema=V2G.get_schema(), - ) - - # Load - ( - v2g.df.write.partitionBy("chromosome") - .mode(session.write_mode) - .parquet(v2g_path) - ) diff --git a/tests/gentropy/conftest.py b/tests/gentropy/conftest.py index 21f05dcf3..a70c1a87d 100644 --- a/tests/gentropy/conftest.py +++ b/tests/gentropy/conftest.py @@ -25,7 +25,6 @@ from gentropy.dataset.study_locus import StudyLocus from gentropy.dataset.study_locus_overlap import StudyLocusOverlap from gentropy.dataset.summary_statistics import SummaryStatistics -from gentropy.dataset.v2g import V2G from gentropy.dataset.variant_index import VariantIndex from gentropy.datasource.eqtl_catalogue.finemapping import EqtlCatalogueFinemapping from gentropy.datasource.eqtl_catalogue.study_index import EqtlCatalogueStudyIndex @@ -252,29 +251,17 @@ def mock_intervals(spark: SparkSession) -> Intervals: @pytest.fixture() -def mock_v2g(spark: SparkSession) -> V2G: - """Mock v2g dataset.""" - v2g_schema = V2G.get_schema() - - data_spec = ( - dg.DataGenerator( - spark, - rows=400, - partitions=4, - randomSeedMethod="hash_fieldname", - ) - .withSchema(v2g_schema) - .withColumnSpec("distance", percentNulls=0.1) - .withColumnSpec("resourceScore", percentNulls=0.1) - .withColumnSpec("score", percentNulls=0.1) - .withColumnSpec("pmid", percentNulls=0.1) - .withColumnSpec("biofeature", percentNulls=0.1) - .withColumnSpec("variantFunctionalConsequenceId", percentNulls=0.1) - .withColumnSpec("isHighQualityPlof", percentNulls=0.1) +def mock_variant_consequence_to_score(spark: SparkSession) -> DataFrame: + """Slice of the VEP consequence to score table.""" + return spark.createDataFrame( + [ + ("SO_0001893", "transcript_ablation", 1.0), + ("SO_0001822", "inframe_deletion", 0.66), + ("SO_0001567", "stop_retained_variant", 0.33), + ], + ["variantFunctionalConsequenceId", "label", "score"], ) - return V2G(_df=data_spec.build(), 
_schema=v2g_schema) - @pytest.fixture() def mock_variant_index(spark: SparkSession) -> VariantIndex: @@ -386,9 +373,9 @@ def mock_summary_statistics_data(spark: SparkSession) -> DataFrame: # Allowing missingness: .withColumnSpec("standardError", percentNulls=0.1) # Making sure p-values are below 1: - ).build() + ) - return data_spec + return data_spec.build() @pytest.fixture() @@ -620,7 +607,7 @@ def mock_l2g_feature_matrix(spark: SparkSession) -> L2GFeatureMatrix: ("1", "gene1", 100.0, None), ("2", "gene2", 1000.0, 0.0), ], - "studyLocusId STRING, geneId STRING, distanceTssMean FLOAT, distanceTssMinimum FLOAT", + "studyLocusId STRING, geneId STRING, distanceTssMean FLOAT, distanceSentinelTssMinimum FLOAT", ), with_gold_standard=False, ) diff --git a/tests/gentropy/dataset/test_intervals.py b/tests/gentropy/dataset/test_intervals.py deleted file mode 100644 index 26d79acd1..000000000 --- a/tests/gentropy/dataset/test_intervals.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Tests on LD index.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from gentropy.dataset.v2g import V2G - -if TYPE_CHECKING: - from gentropy.dataset.intervals import Intervals - from gentropy.dataset.variant_index import VariantIndex - - -def test_interval_v2g_creation( - mock_intervals: Intervals, mock_variant_index: VariantIndex -) -> None: - """Test creation of V2G from intervals.""" - assert isinstance(mock_intervals.v2g(mock_variant_index), V2G) diff --git a/tests/gentropy/dataset/test_l2g.py b/tests/gentropy/dataset/test_l2g.py index 125352f8e..f73b6f7c2 100644 --- a/tests/gentropy/dataset/test_l2g.py +++ b/tests/gentropy/dataset/test_l2g.py @@ -177,7 +177,7 @@ def test_calculate_feature_missingness_rate( spark: SparkSession, mock_l2g_feature_matrix: L2GFeatureMatrix ) -> None: """Test L2GFeatureMatrix.calculate_feature_missingness_rate.""" - expected_missingness = {"distanceTssMean": 0.0, "distanceTssMinimum": 1.0} + expected_missingness = {"distanceTssMean": 0.0, "distanceSentinelTssMinimum": 1.0} observed_missingness = mock_l2g_feature_matrix.calculate_feature_missingness_rate() assert isinstance(observed_missingness, dict) assert mock_l2g_feature_matrix.features_list is not None and len( diff --git a/tests/gentropy/dataset/test_l2g_feature.py b/tests/gentropy/dataset/test_l2g_feature.py index 82df2dd4f..18d8a4066 100644 --- a/tests/gentropy/dataset/test_l2g_feature.py +++ b/tests/gentropy/dataset/test_l2g_feature.py @@ -1,28 +1,62 @@ -"""Test L2G feature generation.""" +"""Test locus-to-gene feature generation.""" from __future__ import annotations from typing import TYPE_CHECKING, Any +import pyspark.sql.functions as f import pytest +from pyspark.sql.types import ( + ArrayType, + BooleanType, + IntegerType, + LongType, + StringType, + StructField, + StructType, +) -from gentropy.dataset.l2g_feature import ( +from gentropy.dataset.colocalisation import Colocalisation +from gentropy.dataset.l2g_features.colocalisation import ( EQtlColocClppMaximumFeature, + EQtlColocClppMaximumNeighbourhoodFeature, EQtlColocH4MaximumFeature, - L2GFeature, + EQtlColocH4MaximumNeighbourhoodFeature, PQtlColocClppMaximumFeature, + PQtlColocClppMaximumNeighbourhoodFeature, PQtlColocH4MaximumFeature, + PQtlColocH4MaximumNeighbourhoodFeature, SQtlColocClppMaximumFeature, + SQtlColocClppMaximumNeighbourhoodFeature, SQtlColocH4MaximumFeature, + SQtlColocH4MaximumNeighbourhoodFeature, TuQtlColocClppMaximumFeature, + TuQtlColocClppMaximumNeighbourhoodFeature, TuQtlColocH4MaximumFeature, + 
TuQtlColocH4MaximumNeighbourhoodFeature, + common_colocalisation_feature_logic, + common_neighbourhood_colocalisation_feature_logic, +) +from gentropy.dataset.l2g_features.distance import ( + DistanceFootprintMeanFeature, + DistanceFootprintMeanNeighbourhoodFeature, + DistanceSentinelFootprintFeature, + DistanceSentinelFootprintNeighbourhoodFeature, + DistanceSentinelTssFeature, + DistanceSentinelTssNeighbourhoodFeature, + DistanceTssMeanFeature, + DistanceTssMeanNeighbourhoodFeature, + common_distance_feature_logic, + common_neighbourhood_distance_feature_logic, ) +from gentropy.dataset.l2g_features.l2g_feature import L2GFeature +from gentropy.dataset.study_index import StudyIndex +from gentropy.dataset.study_locus import StudyLocus +from gentropy.dataset.variant_index import VariantIndex from gentropy.method.l2g.feature_factory import L2GFeatureInputLoader if TYPE_CHECKING: - from gentropy.dataset.colocalisation import Colocalisation - from gentropy.dataset.study_index import StudyIndex - from gentropy.dataset.study_locus import StudyLocus + from pyspark.sql import SparkSession @pytest.mark.parametrize( @@ -36,6 +70,22 @@ PQtlColocClppMaximumFeature, SQtlColocClppMaximumFeature, TuQtlColocClppMaximumFeature, + EQtlColocClppMaximumNeighbourhoodFeature, + PQtlColocClppMaximumNeighbourhoodFeature, + SQtlColocClppMaximumNeighbourhoodFeature, + TuQtlColocClppMaximumNeighbourhoodFeature, + EQtlColocH4MaximumNeighbourhoodFeature, + PQtlColocH4MaximumNeighbourhoodFeature, + SQtlColocH4MaximumNeighbourhoodFeature, + TuQtlColocH4MaximumNeighbourhoodFeature, + DistanceTssMeanFeature, + DistanceTssMeanNeighbourhoodFeature, + DistanceFootprintMeanFeature, + DistanceFootprintMeanNeighbourhoodFeature, + DistanceSentinelTssFeature, + DistanceSentinelTssNeighbourhoodFeature, + DistanceSentinelFootprintFeature, + DistanceSentinelFootprintNeighbourhoodFeature, ], ) def test_feature_factory_return_type( @@ -43,11 +93,13 @@ def test_feature_factory_return_type( mock_study_locus: StudyLocus, mock_colocalisation: Colocalisation, mock_study_index: StudyIndex, + mock_variant_index: VariantIndex, ) -> None: """Test that every feature factory returns a L2GFeature dataset.""" loader = L2GFeatureInputLoader( colocalisation=mock_colocalisation, study_index=mock_study_index, + variant_index=mock_variant_index, study_locus=mock_study_locus, ) feature_dataset = feature_class.compute( @@ -57,3 +109,380 @@ def test_feature_factory_return_type( ), ) assert isinstance(feature_dataset, L2GFeature) + + +class TestCommonColocalisationFeatureLogic: + """Test the common logic of the colocalisation features.""" + + def test__common_colocalisation_feature_logic( + self: TestCommonColocalisationFeatureLogic, + spark: SparkSession, + ) -> None: + """Test the common logic of the colocalisation features. + + The test data associates studyLocusId1 with gene1 based on the colocalisation with studyLocusId2 and studyLocusId3. + The H4 value of number 2 is higher, therefore the feature value should be based on that. 
+ """ + feature_name = "eQtlColocH4Maximum" + observed_df = common_colocalisation_feature_logic( + self.sample_study_loci_to_annotate, + self.colocalisation_method, + self.colocalisation_metric, + feature_name, + self.qtl_type, + colocalisation=self.sample_colocalisation, + study_index=self.sample_studies, + study_locus=self.sample_study_locus, + ) + expected_df = spark.createDataFrame( + [ + { + "studyLocusId": "1", + "geneId": "gene1", + "eQtlColocH4Maximum": 0.81, + }, + { + "studyLocusId": "1", + "geneId": "gene2", + "eQtlColocH4Maximum": 0.9, + }, + ], + ).select("studyLocusId", "geneId", "eQtlColocH4Maximum") + assert ( + observed_df.collect() == expected_df.collect() + ), "The feature values are not as expected." + + def test__common_neighbourhood_colocalisation_feature_logic( + self: TestCommonColocalisationFeatureLogic, spark: SparkSession + ) -> None: + """Test the common logic of the neighbourhood colocalisation features.""" + feature_name = "eQtlColocH4MaximumNeighbourhood" + observed_df = common_neighbourhood_colocalisation_feature_logic( + self.sample_study_loci_to_annotate, + self.colocalisation_method, + self.colocalisation_metric, + feature_name, + self.qtl_type, + colocalisation=self.sample_colocalisation, + study_index=self.sample_studies, + study_locus=self.sample_study_locus, + ) + expected_df = spark.createDataFrame( + [ + { + "studyLocusId": "1", + "geneId": "gene1", + "eQtlColocH4MaximumNeighbourhood": 0.08999999999999997, + }, + { + "studyLocusId": "1", + "geneId": "gene2", + "eQtlColocH4MaximumNeighbourhood": 0.0, + }, + ], + ).select("studyLocusId", "geneId", "eQtlColocH4MaximumNeighbourhood") + assert ( + observed_df.collect() == expected_df.collect() + ), "The expected and observed dataframes do not match." + + @pytest.fixture(autouse=True) + def _setup(self: TestCommonColocalisationFeatureLogic, spark: SparkSession) -> None: + """Set up the test variables.""" + self.colocalisation_method = "Coloc" + self.colocalisation_metric = "h4" + self.qtl_type = "eqtl" + + self.sample_study_loci_to_annotate = StudyLocus( + _df=spark.createDataFrame( + [ + { + "studyLocusId": "1", + "variantId": "lead1", + "studyId": "study1", # this is a GWAS + "chromosome": "1", + }, + ] + ), + _schema=StudyLocus.get_schema(), + ) + self.sample_colocalisation = Colocalisation( + _df=spark.createDataFrame( + [ + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "2", + "chromosome": "1", + "colocalisationMethod": "COLOC", + "numberColocalisingVariants": 1, + "h4": 0.81, + "rightStudyType": "eqtl", + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "3", # qtl linked to the same gene as studyLocusId 2 with a lower score + "chromosome": "1", + "colocalisationMethod": "COLOC", + "numberColocalisingVariants": 1, + "h4": 0.50, + "rightStudyType": "eqtl", + }, + { + "leftStudyLocusId": "1", + "rightStudyLocusId": "4", # qtl linked to a diff gene and with the highest score + "chromosome": "1", + "colocalisationMethod": "COLOC", + "numberColocalisingVariants": 1, + "h4": 0.90, + "rightStudyType": "eqtl", + }, + ], + schema=Colocalisation.get_schema(), + ), + _schema=Colocalisation.get_schema(), + ) + self.sample_study_locus = StudyLocus( + _df=spark.createDataFrame( + [ + { + "studyLocusId": "1", + "variantId": "lead1", + "studyId": "study1", # this is a GWAS + "chromosome": "1", + }, + { + "studyLocusId": "2", + "variantId": "lead1", + "studyId": "study2", # this is a QTL (same gee) + "chromosome": "1", + }, + { + "studyLocusId": "3", + "variantId": "lead1", + "studyId": "study3", # this 
is another QTL (same gene) + "chromosome": "1", + }, + { + "studyLocusId": "4", + "variantId": "lead1", + "studyId": "study4", # this is another QTL (diff gene) + "chromosome": "1", + }, + ] + ), + _schema=StudyLocus.get_schema(), + ) + self.sample_studies = StudyIndex( + _df=spark.createDataFrame( + [ + { + "studyId": "study1", + "studyType": "gwas", + "geneId": None, + "traitFromSource": "trait1", + "projectId": "project1", + }, + { + "studyId": "study2", + "studyType": "eqtl", + "geneId": "gene1", + "traitFromSource": "trait2", + "projectId": "project2", + }, + { + "studyId": "study3", + "studyType": "eqtl", + "geneId": "gene1", + "traitFromSource": "trait3", + "projectId": "project3", + }, + { + "studyId": "study4", + "studyType": "eqtl", + "geneId": "gene2", + "traitFromSource": "trait4", + "projectId": "project4", + }, + ] + ), + _schema=StudyIndex.get_schema(), + ) + + +class TestCommonDistanceFeatureLogic: + """Test the CommonDistanceFeatureLogic methods.""" + + @pytest.mark.parametrize( + ("feature_name", "expected_data"), + [ + ( + "distanceSentinelTss", + [ + { + "studyLocusId": "1", + "geneId": "gene1", + "distanceSentinelTss": 0.0, + }, + { + "studyLocusId": "1", + "geneId": "gene2", + "distanceSentinelTss": 0.95, + }, + ], + ), + ( + "distanceTssMean", + [ + {"studyLocusId": "1", "geneId": "gene1", "distanceTssMean": 0.09}, + {"studyLocusId": "1", "geneId": "gene2", "distanceTssMean": 0.65}, + ], + ), + ], + ) + def test_common_distance_feature_logic( + self: TestCommonDistanceFeatureLogic, + spark: SparkSession, + feature_name: str, + expected_data: dict[str, Any], + ) -> None: + """Test the logic of the function that extracts features from distance. + + 2 tests: + - distanceSentinelTss: distance of the sentinel is 10, the max distance is 10. In log scale, the score is 0. + - distanceTssMean: avg distance of any variant in the credible set, weighted by its posterior. + """ + observed_df = ( + common_distance_feature_logic( + self.sample_study_locus, + variant_index=self.sample_variant_index, + feature_name=feature_name, + distance_type=self.distance_type, + genomic_window=10, + ) + .withColumn(feature_name, f.round(f.col(feature_name), 2)) + .orderBy(feature_name) + ) + expected_df = ( + spark.createDataFrame(expected_data) + .select("studyLocusId", "geneId", feature_name) + .orderBy(feature_name) + ) + assert ( + observed_df.collect() == expected_df.collect() + ), f"Expected and observed dataframes are not equal for feature {feature_name}." + + def test_common_neighbourhood_colocalisation_feature_logic( + self: TestCommonDistanceFeatureLogic, + spark: SparkSession, + ) -> None: + """Test the logic of the function that extracts the distance between the sentinel of a credible set and the nearby genes.""" + feature_name = "distanceSentinelTssNeighbourhood" + observed_df = ( + common_neighbourhood_distance_feature_logic( + self.sample_study_locus, + variant_index=self.sample_variant_index, + feature_name=feature_name, + distance_type=self.distance_type, + genomic_window=10, + ) + .withColumn(feature_name, f.round(f.col(feature_name), 2)) + .orderBy(f.col(feature_name).asc()) + ) + expected_df = spark.createDataFrame( + (["1", "gene1", -0.48], ["1", "gene2", 0.48]), + ["studyLocusId", "geneId", feature_name], + ).orderBy(feature_name) + assert ( + observed_df.collect() == expected_df.collect() + ), "Output doesn't meet the expectation." 
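The expected values in these two tests follow directly from the scoring formula with `genomic_window=10` (the sentinel sits at distance 10 from gene1 and 2 from gene2); a quick arithmetic check:

```python
from math import log10

window = 10

# Sentinel branch: log10(window - distance + 1), unweighted
gene1 = log10(window - 10 + 1)  # log10(1)  = 0.0
gene2 = log10(window - 2 + 1)   # log10(9) ~= 0.954 -> rounds to 0.95

# Neighbourhood: local score minus the locus-wide mean
locus_mean = (gene1 + gene2) / 2
print(round(gene1 - locus_mean, 2), round(gene2 - locus_mean, 2))  # -0.48 0.48
```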
+ + @pytest.fixture(autouse=True) + def _setup(self: TestCommonDistanceFeatureLogic, spark: SparkSession) -> None: + """Set up testing fixtures.""" + self.distance_type = "distanceFromTss" + self.sample_study_locus = StudyLocus( + _df=spark.createDataFrame( + [ + { + "studyLocusId": "1", + "variantId": "lead1", + "studyId": "study1", + "locus": [ + { + "variantId": "lead1", + "posteriorProbability": 0.5, + }, + { + "variantId": "tag1", + "posteriorProbability": 0.5, + }, + ], + "chromosome": "1", + }, + ], + StudyLocus.get_schema(), + ), + _schema=StudyLocus.get_schema(), + ) + self.variant_index_schema = StructType( + [ + StructField("variantId", StringType(), True), + StructField("chromosome", StringType(), True), + StructField("position", IntegerType(), True), + StructField("referenceAllele", StringType(), True), + StructField("alternateAllele", StringType(), True), + StructField( + "transcriptConsequences", + ArrayType( + StructType( + [ + StructField("distanceFromTss", LongType(), True), + StructField("targetId", StringType(), True), + StructField("isEnsemblCanonical", BooleanType(), True), + ] + ) + ), + True, + ), + ] + ) + self.sample_variant_index = VariantIndex( + _df=spark.createDataFrame( + [ + ( + "lead1", + "chrom", + 1, + "A", + "T", + [ + { + "distanceFromTss": 10, + "targetId": "gene1", + "isEnsemblCanonical": True, + }, + { + "distanceFromTss": 2, + "targetId": "gene2", + "isEnsemblCanonical": True, + }, + ], + ), + ( + "tag1", + "chrom", + 1, + "A", + "T", + [ + { + "distanceFromTss": 5, + "targetId": "gene1", + "isEnsemblCanonical": True, + }, + ], + ), + ], + self.variant_index_schema, + ), + _schema=VariantIndex.get_schema(), + ) diff --git a/tests/gentropy/dataset/test_v2g.py b/tests/gentropy/dataset/test_v2g.py deleted file mode 100644 index 24a917508..000000000 --- a/tests/gentropy/dataset/test_v2g.py +++ /dev/null @@ -1,23 +0,0 @@ -"""Tests V2G dataset.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from gentropy.dataset.v2g import V2G - -if TYPE_CHECKING: - from gentropy.dataset.gene_index import GeneIndex - - -def test_v2g_creation(mock_v2g: V2G) -> None: - """Test v2g creation with mock data.""" - assert isinstance(mock_v2g, V2G) - - -def test_v2g_filter_by_genes(mock_v2g: V2G, mock_gene_index: GeneIndex) -> None: - """Test v2g filter by genes.""" - assert isinstance( - mock_v2g.filter_by_genes(mock_gene_index), - V2G, - ) diff --git a/tests/gentropy/dataset/test_variant_index.py b/tests/gentropy/dataset/test_variant_index.py index 12afba89f..29a6ef035 100644 --- a/tests/gentropy/dataset/test_variant_index.py +++ b/tests/gentropy/dataset/test_variant_index.py @@ -8,12 +8,10 @@ from pyspark.sql import functions as f from pyspark.sql import types as t -from gentropy.dataset.gene_index import GeneIndex -from gentropy.dataset.v2g import V2G from gentropy.dataset.variant_index import VariantIndex if TYPE_CHECKING: - from pyspark.sql import SparkSession + from pyspark.sql import DataFrame, SparkSession def test_variant_index_creation(mock_variant_index: VariantIndex) -> None: @@ -21,20 +19,6 @@ def test_variant_index_creation(mock_variant_index: VariantIndex) -> None: assert isinstance(mock_variant_index, VariantIndex) -def test_get_plof_v2g( - mock_variant_index: VariantIndex, mock_gene_index: GeneIndex -) -> None: - """Test get_plof_v2g with mock variant annotation.""" - assert isinstance(mock_variant_index.get_plof_v2g(mock_gene_index), V2G) - - -def test_get_distance_to_tss( - mock_variant_index: VariantIndex, mock_gene_index: 
GeneIndex -) -> None: - """Test get_distance_to_tss with mock variant annotation.""" - assert isinstance(mock_variant_index.get_distance_to_tss(mock_gene_index), V2G) - - class TestVariantIndex: """Collection of tests around the functionality and shape of the variant index.""" @@ -147,3 +131,47 @@ def test_rsid_column_updated(self: TestVariantIndex) -> None: .count() == 2 ) + + @pytest.mark.parametrize( + "distance_type", ["distanceFromTss", "distanceFromFootprint"] + ) + def test_get_distance_to_gene( + self: TestVariantIndex, mock_variant_index: VariantIndex, distance_type: str + ) -> None: + """Assert that the function returns a df with the requested columns.""" + expected_cols = ["variantId", "targetId", distance_type] + observed = mock_variant_index.get_distance_to_gene(distance_type=distance_type) + for col in expected_cols: + assert col in observed.columns, f"Column {col} not in {observed.columns}" + + def test_get_most_severe_gene_consequence( + self: TestVariantIndex, + mock_variant_index: VariantIndex, + mock_variant_consequence_to_score: DataFrame, + ) -> None: + """Assert that the function returns a df with the requested columns.""" + expected_cols = [ + "variantId", + "targetId", + "mostSevereVariantFunctionalConsequenceId", + "severityScore", + ] + observed = mock_variant_index.get_most_severe_gene_consequence( + vep_consequences=mock_variant_consequence_to_score + ) + for col in expected_cols: + assert col in observed.columns, f"Column {col} not in {observed.columns}" + + def test_get_loftee( + self: TestVariantIndex, mock_variant_index: VariantIndex + ) -> None: + """Assert that the function returns a df with the requested columns.""" + expected_cols = [ + "variantId", + "targetId", + "lofteePrediction", + "isHighQualityPlof", + ] + observed = mock_variant_index.get_loftee() + for col in expected_cols: + assert col in observed.columns, f"Column {col} not in {observed.columns}" diff --git a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py index 78f97d48f..aa36359ca 100644 --- a/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py +++ b/tests/gentropy/datasource/open_targets/test_l2g_gold_standard.py @@ -6,11 +6,20 @@ import pytest from pyspark.sql import DataFrame +from pyspark.sql.types import ( + ArrayType, + BooleanType, + IntegerType, + LongType, + StringType, + StructField, + StructType, +) from gentropy.dataset.l2g_feature_matrix import L2GFeatureMatrix from gentropy.dataset.l2g_gold_standard import L2GGoldStandard from gentropy.dataset.study_index import StudyIndex -from gentropy.dataset.v2g import V2G +from gentropy.dataset.variant_index import VariantIndex from gentropy.datasource.open_targets.l2g_gold_standard import ( OpenTargetsL2GGoldStandard, ) @@ -25,13 +34,13 @@ def test_open_targets_as_l2g_gold_standard( sample_l2g_gold_standard: DataFrame, - mock_v2g: V2G, + mock_variant_index: VariantIndex, ) -> None: """Test L2G gold standard from OTG curation.""" assert isinstance( OpenTargetsL2GGoldStandard.as_l2g_gold_standard( sample_l2g_gold_standard, - mock_v2g, + mock_variant_index, ), L2GGoldStandard, ) @@ -81,19 +90,52 @@ def _setup(self: TestExpandGoldStandardWithNegatives, spark: SparkSession) -> No ["variantId", "geneId", "studyId"], ) - sample_v2g_df = spark.createDataFrame( - [ - ("variant1", "gene1", 5, "X", "X", "X"), - ("variant1", "gene3", 10, "X", "X", "X"), - ], + sample_variant_index_df = spark.createDataFrame( [ - "variantId", - "geneId", - "distance", - 
"chromosome", - "datatypeId", - "datasourceId", + ( + "variant1", + "chrom", + 1, + "A", + "T", + [ + { + "distanceFromTss": 5, + "targetId": "gene1", + "isEnsemblCanonical": True, + }, + { + "distanceFromTss": 10, + "targetId": "gene3", + "isEnsemblCanonical": True, + }, + ], + ), ], + StructType( + [ + StructField("variantId", StringType(), True), + StructField("chromosome", StringType(), True), + StructField("position", IntegerType(), True), + StructField("referenceAllele", StringType(), True), + StructField("alternateAllele", StringType(), True), + StructField( + "transcriptConsequences", + ArrayType( + StructType( + [ + StructField("distanceFromTss", LongType(), True), + StructField("targetId", StringType(), True), + StructField( + "isEnsemblCanonical", BooleanType(), True + ), + ] + ) + ), + True, + ), + ] + ), ) self.expected_expanded_gs = spark.createDataFrame( @@ -107,7 +149,9 @@ def _setup(self: TestExpandGoldStandardWithNegatives, spark: SparkSession) -> No self.observed_df = ( OpenTargetsL2GGoldStandard.expand_gold_standard_with_negatives( self.sample_positive_set, - V2G(_df=sample_v2g_df, _schema=V2G.get_schema()), + VariantIndex( + _df=sample_variant_index_df, _schema=VariantIndex.get_schema() + ), ) ) diff --git a/tests/gentropy/test_schemas.py b/tests/gentropy/test_schemas.py index 6840e3207..1b06076d0 100644 --- a/tests/gentropy/test_schemas.py +++ b/tests/gentropy/test_schemas.py @@ -18,7 +18,7 @@ from _pytest.fixtures import FixtureRequest from gentropy.dataset.gene_index import GeneIndex - from gentropy.dataset.v2g import V2G + from gentropy.dataset.l2g_prediction import L2GPrediction SCHEMA_DIR = "src/gentropy/assets/schemas" @@ -75,21 +75,23 @@ def test_schema_columns_camelcase(schema_json: str) -> None: class TestValidateSchema: - """Test validate_schema method using V2G (unnested) and GeneIndex (nested) as a testing dataset.""" + """Test validate_schema method using L2GPrediction (unnested) and GeneIndex (nested) as a testing dataset.""" @pytest.fixture() def mock_dataset_instance( self: TestValidateSchema, request: FixtureRequest - ) -> V2G | GeneIndex: + ) -> L2GPrediction | GeneIndex: """Meta fixture to return the value of any requested fixture.""" return request.getfixturevalue(request.param) @pytest.mark.parametrize( - "mock_dataset_instance", ["mock_v2g", "mock_gene_index"], indirect=True + "mock_dataset_instance", + ["mock_l2g_predictions", "mock_gene_index"], + indirect=True, ) def test_validate_schema_extra_field( self: TestValidateSchema, - mock_dataset_instance: V2G | GeneIndex, + mock_dataset_instance: L2GPrediction | GeneIndex, ) -> None: """Test that validate_schema raises an error if the observed schema has an extra field.""" with pytest.raises(SchemaValidationError, match="extraField"): @@ -98,22 +100,26 @@ def test_validate_schema_extra_field( ) @pytest.mark.parametrize( - "mock_dataset_instance", ["mock_v2g", "mock_gene_index"], indirect=True + "mock_dataset_instance", + ["mock_l2g_predictions", "mock_gene_index"], + indirect=True, ) def test_validate_schema_missing_field( self: TestValidateSchema, - mock_dataset_instance: V2G | GeneIndex, + mock_dataset_instance: L2GPrediction | GeneIndex, ) -> None: """Test that validate_schema raises an error if the observed schema is missing a required field, geneId in this case.""" with pytest.raises(SchemaValidationError, match="geneId"): mock_dataset_instance.df = mock_dataset_instance.df.drop("geneId") @pytest.mark.parametrize( - "mock_dataset_instance", ["mock_v2g", "mock_gene_index"], indirect=True + 
"mock_dataset_instance", + ["mock_l2g_predictions", "mock_gene_index"], + indirect=True, ) def test_validate_schema_duplicated_field( self: TestValidateSchema, - mock_dataset_instance: V2G | GeneIndex, + mock_dataset_instance: L2GPrediction | GeneIndex, ) -> None: """Test that validate_schema raises an error if the observed schema has a duplicated field, geneId in this case.""" with pytest.raises(SchemaValidationError, match="geneId"): @@ -122,11 +128,13 @@ def test_validate_schema_duplicated_field( ) @pytest.mark.parametrize( - "mock_dataset_instance", ["mock_v2g", "mock_gene_index"], indirect=True + "mock_dataset_instance", + ["mock_l2g_predictions", "mock_gene_index"], + indirect=True, ) def test_validate_schema_different_datatype( self: TestValidateSchema, - mock_dataset_instance: V2G | GeneIndex, + mock_dataset_instance: L2GPrediction | GeneIndex, ) -> None: """Test that validate_schema raises an error if any field in the observed schema has a different type than expected.""" with pytest.raises(SchemaValidationError, match="geneId"):