From 793a58bce9a2c4b029556a0142f9420954a0ba54 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?= <45119610+ireneisdoomed@users.noreply.github.com>
Date: Fri, 1 Dec 2023 11:48:29 +0000
Subject: [PATCH 1/2] fix: correct and test study splitter when subStudyDescription is the same (#289)

---
 .../datasource/gwas_catalog/study_splitter.py |  2 +-
 .../test_gwas_catalog_study_splitter.py       | 70 +++++++++++++++++++
 2 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/src/otg/datasource/gwas_catalog/study_splitter.py b/src/otg/datasource/gwas_catalog/study_splitter.py
index 91e874ee0..339afb068 100644
--- a/src/otg/datasource/gwas_catalog/study_splitter.py
+++ b/src/otg/datasource/gwas_catalog/study_splitter.py
@@ -74,7 +74,7 @@ def _resolve_study_id(study_id: Column, sub_study_description: Column) -> Column
         """
         split_w = Window.partitionBy(study_id).orderBy(sub_study_description)
         row_number = f.dense_rank().over(split_w)
-        substudy_count = f.count(row_number).over(split_w)
+        substudy_count = f.approx_count_distinct(row_number).over(split_w)
         return f.when(substudy_count == 1, study_id).otherwise(
             f.concat_ws("_", study_id, row_number)
         )
diff --git a/tests/datasource/gwas_catalog/test_gwas_catalog_study_splitter.py b/tests/datasource/gwas_catalog/test_gwas_catalog_study_splitter.py
index a79f9eb4c..6ad0e71e3 100644
--- a/tests/datasource/gwas_catalog/test_gwas_catalog_study_splitter.py
+++ b/tests/datasource/gwas_catalog/test_gwas_catalog_study_splitter.py
@@ -1,10 +1,18 @@
 """Tests GWAS Catalog study splitter."""
 from __future__ import annotations
 
+from typing import TYPE_CHECKING, Any
+
+import pyspark.sql.functions as f
+import pytest
+
 from otg.datasource.gwas_catalog.associations import GWASCatalogAssociations
 from otg.datasource.gwas_catalog.study_index import GWASCatalogStudyIndex
 from otg.datasource.gwas_catalog.study_splitter import GWASCatalogStudySplitter
 
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
 
 def test_gwas_catalog_splitter_split(
     mock_study_index_gwas_catalog: GWASCatalogStudyIndex,
@@ -17,3 +25,65 @@ def test_gwas_catalog_splitter_split(
 
     assert isinstance(d1, GWASCatalogStudyIndex)
     assert isinstance(d2, GWASCatalogAssociations)
+
+
+@pytest.mark.parametrize(
+    "observed, expected",
+    [
+        # Test 1 - it shouldn't split
+        (
+            # observed - 2 associations with the same subStudy annotation
+            [
+                (
+                    "varA",
+                    "GCST003436",
+                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
+                ),
+                (
+                    "varB",
+                    "GCST003436",
+                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
+                ),
+            ],
+            # expected - 2 associations with the same unsplit updatedStudyId
+            [
+                ("GCST003436",),
+                ("GCST003436",),
+            ],
+        ),
+        # Test 2 - it should split
+        (
+            # observed - 2 associations with different subStudy annotations
+            [
+                (
+                    "varA",
+                    "GCST003436",
+                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
+                ),
+                (
+                    "varB",
+                    "GCST003436",
+                    "Uterine carcinoma|no_pvalue_text|EFO_0002919",
+                ),
+            ],
+            # expected - the second association gets a split updatedStudyId
+            [
+                ("GCST003436",),
+                ("GCST003436_2",),
+            ],
+        ),
+    ],
+)
+def test__resolve_study_id(
+    spark: SparkSession, observed: list[Any], expected: list[Any]
+) -> None:
+    """Test _resolve_study_id."""
+    observed_df = spark.createDataFrame(
+        observed, schema=["variantId", "studyId", "subStudyDescription"]
+    ).select(
+        GWASCatalogStudySplitter._resolve_study_id(
+            f.col("studyId"), f.col("subStudyDescription")
+        ).alias("updatedStudyId")
+    )
+    expected_df = spark.createDataFrame(expected, schema=["updatedStudyId"])
+    assert observed_df.collect() == expected_df.collect()
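A minimal sketch (not part of either patch) of the window logic that patch 1 corrects. It assumes a local SparkSession and reuses the data from Test 1; the intermediate column names are illustrative. Counting rows over the window conflates "many associations" with "many substudies", while counting distinct dense ranks does not:

from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as f

spark = SparkSession.builder.getOrCreate()

# Two associations of one study sharing a single subStudyDescription:
# no split expected.
df = spark.createDataFrame(
    [
        ("varA", "GCST003436", "Endometrial cancer|no_pvalue_text|EFO_1001512"),
        ("varB", "GCST003436", "Endometrial cancer|no_pvalue_text|EFO_1001512"),
    ],
    ["variantId", "studyId", "subStudyDescription"],
)

split_w = Window.partitionBy("studyId").orderBy("subStudyDescription")

df = (
    df.withColumn("rank", f.dense_rank().over(split_w))
    # Pre-fix behaviour: f.count sees 2 rows in the frame, so the count is
    # never 1 and both rows are wrongly renamed to GCST003436_1.
    .withColumn("row_count", f.count("rank").over(split_w))
    # Post-fix behaviour: there is only 1 distinct rank, so the original
    # studyId is kept.
    .withColumn("substudy_count", f.approx_count_distinct("rank").over(split_w))
)
df.show(truncate=False)

Running this shows row_count == 2 but substudy_count == 1 for both rows, which is exactly the duplicated-description case Test 1 pins down.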
From a6cc21dc21a3ff30f0d1c106025e47e0063cea79 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Irene=20L=C3=B3pez?= <45119610+ireneisdoomed@users.noreply.github.com>
Date: Fri, 1 Dec 2023 15:42:53 +0000
Subject: [PATCH 2/2] fix(clump): read input files recursively (#292)

---
 src/otg/clump.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/otg/clump.py b/src/otg/clump.py
index 8cb756fa0..c8b587126 100644
--- a/src/otg/clump.py
+++ b/src/otg/clump.py
@@ -45,7 +45,9 @@ def __post_init__(self: ClumpStep) -> None:
         Raises:
             ValueError: If study index and LD index paths are not provided for study locus.
         """
-        input_cols = self.session.spark.read.parquet(self.input_path).columns
+        input_cols = self.session.spark.read.parquet(
+            self.input_path, recursiveFileLookup=True
+        ).columns
         if "studyLocusId" in input_cols:
             if self.study_index_path is None or self.ld_index_path is None:
                 raise ValueError(
@@ -59,7 +61,9 @@ def __post_init__(self: ClumpStep) -> None:
                 study_index=study_index, ld_index=ld_index
             ).clump()
         else:
-            sumstats = SummaryStatistics.from_parquet(self.session, self.input_path)
+            sumstats = SummaryStatistics.from_parquet(
+                self.session, self.input_path, recursiveFileLookup=True
+            )
             clumped_study_locus = sumstats.window_based_clumping(
                 locus_collect_distance=self.locus_collect_distance
            )
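A sketch of the reader behaviour patch 2 relies on; the directory layout and paths below are hypothetical, not taken from the pipeline. By default spark.read.parquet only discovers files at the top level of a path plus hive-style key=value partition directories, so per-study subdirectories match neither and the reader has to be told to walk the tree:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical layout with one subdirectory per study:
#   /data/sumstats/GCST000001/part-00000.parquet
#   /data/sumstats/GCST000002/part-00000.parquet
# The subdirectory names are not key=value pairs, so a plain
# spark.read.parquet("/data/sumstats") ignores the nested files and
# typically fails with an "Unable to infer schema" error.
df = spark.read.parquet("/data/sumstats", recursiveFileLookup=True)

# Equivalent option-style spelling:
df = spark.read.option("recursiveFileLookup", True).parquet("/data/sumstats")

One trade-off to keep in mind: recursiveFileLookup disables partition inference, which is harmless here because the per-study directories carry no partition columns.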