
Commit fc4b33e

Merge branch 'main' into il-clump-fix
2 parents: 893ceaa + 793a58b

File tree

2 files changed: +71 -1 lines changed


src/otg/datasource/gwas_catalog/study_splitter.py

Lines changed: 1 addition & 1 deletion
@@ -74,7 +74,7 @@ def _resolve_study_id(study_id: Column, sub_study_description: Column) -> Column
         """
         split_w = Window.partitionBy(study_id).orderBy(sub_study_description)
         row_number = f.dense_rank().over(split_w)
-        substudy_count = f.count(row_number).over(split_w)
+        substudy_count = f.approx_count_distinct(row_number).over(split_w)
         return f.when(substudy_count == 1, study_id).otherwise(
             f.concat_ws("_", study_id, row_number)
         )
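For context (not part of the commit): with the old f.count(row_number), every extra association inflated the per-study count, so a study whose associations all point at a single sub-study could still be split; f.approx_count_distinct(row_number) counts distinct sub-study ranks instead (a plain count(DISTINCT ...) is not supported over a window). Below is a minimal, self-contained sketch that replays the changed expression on the toy rows from the new test, assuming a local SparkSession.

# Standalone sketch, not part of the commit: replays the window logic of
# _resolve_study_id on the two rows used by "Test 2" in the new test file.
import pyspark.sql.functions as f
from pyspark.sql import SparkSession, Window

spark = SparkSession.builder.master("local[1]").getOrCreate()

df = spark.createDataFrame(
    [
        ("varA", "GCST003436", "Endometrial cancer|no_pvalue_text|EFO_1001512"),
        ("varB", "GCST003436", "Uterine carcinoma|no_pvalue_text|EFO_0002919"),
    ],
    schema=["variantId", "studyId", "subStudyDescription"],
)

# Same window as in _resolve_study_id: one partition per study,
# ordered by the sub-study annotation.
split_w = Window.partitionBy("studyId").orderBy("subStudyDescription")
row_number = f.dense_rank().over(split_w)
# Distinct count of sub-study ranks; with the ordered window's default frame
# it is 1 for the first sub-study and grows for the following ones.
substudy_count = f.approx_count_distinct(row_number).over(split_w)

df.withColumn(
    "updatedStudyId",
    f.when(substudy_count == 1, f.col("studyId")).otherwise(
        f.concat_ws("_", f.col("studyId"), row_number)
    ),
).show(truncate=False)
# Per the new test's expectations: varA keeps GCST003436, varB becomes GCST003436_2.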

tests/datasource/gwas_catalog/test_gwas_catalog_study_splitter.py

Lines changed: 70 additions & 0 deletions
@@ -1,10 +1,18 @@
 """Tests GWAS Catalog study splitter."""
 from __future__ import annotations
 
+from typing import TYPE_CHECKING, Any
+
+import pyspark.sql.functions as f
+import pytest
+
 from otg.datasource.gwas_catalog.associations import GWASCatalogAssociations
 from otg.datasource.gwas_catalog.study_index import GWASCatalogStudyIndex
 from otg.datasource.gwas_catalog.study_splitter import GWASCatalogStudySplitter
 
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+
 
 def test_gwas_catalog_splitter_split(
     mock_study_index_gwas_catalog: GWASCatalogStudyIndex,
@@ -17,3 +25,65 @@ def test_gwas_catalog_splitter_split(
 
     assert isinstance(d1, GWASCatalogStudyIndex)
     assert isinstance(d2, GWASCatalogAssociations)
+
+
+@pytest.mark.parametrize(
+    "observed, expected",
+    [
+        # Test 1 - it shouldn't split
+        (
+            # observed - 2 associations with the same subStudy annotation
+            [
+                (
+                    "varA",
+                    "GCST003436",
+                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
+                ),
+                (
+                    "varB",
+                    "GCST003436",
+                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
+                ),
+            ],
+            # expected - 2 associations with the same unsplit updatedStudyId
+            [
+                ("GCST003436",),
+                ("GCST003436",),
+            ],
+        ),
+        # Test 2 - it should split
+        (
+            # observed - 2 associations with different subStudy annotations
+            [
+                (
+                    "varA",
+                    "GCST003436",
+                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
+                ),
+                (
+                    "varB",
+                    "GCST003436",
+                    "Uterine carcinoma|no_pvalue_text|EFO_0002919",
+                ),
+            ],
+            # expected - 2 associations with different, split updatedStudyId values
+            [
+                ("GCST003436",),
+                ("GCST003436_2",),
+            ],
+        ),
+    ],
+)
+def test__resolve_study_id(
+    spark: SparkSession, observed: list[Any], expected: list[Any]
+) -> None:
+    """Test _resolve_study_id."""
+    observed_df = spark.createDataFrame(
+        observed, schema=["variantId", "studyId", "subStudyDescription"]
+    ).select(
+        GWASCatalogStudySplitter._resolve_study_id(
+            f.col("studyId"), f.col("subStudyDescription").alias("updatedStudyId")
+        )
+    )
+    expected_df = spark.createDataFrame(expected, schema=["updatedStudyId"])
+    assert observed_df.collect() == expected_df.collect()
