"""Tests GWAS Catalog study splitter."""
from __future__ import annotations

+from typing import TYPE_CHECKING, Any
+
+import pyspark.sql.functions as f
+import pytest
+
from otg.datasource.gwas_catalog.associations import GWASCatalogAssociations
from otg.datasource.gwas_catalog.study_index import GWASCatalogStudyIndex
from otg.datasource.gwas_catalog.study_splitter import GWASCatalogStudySplitter

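+# SparkSession is only used in type hints, so it is imported under
+# TYPE_CHECKING and never evaluated at runtime.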
+if TYPE_CHECKING:
+    from pyspark.sql import SparkSession
+


def test_gwas_catalog_splitter_split(
    mock_study_index_gwas_catalog: GWASCatalogStudyIndex,
@@ -17,3 +25,65 @@ def test_gwas_catalog_splitter_split(

    assert isinstance(d1, GWASCatalogStudyIndex)
    assert isinstance(d2, GWASCatalogAssociations)
+
+
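+# The cases below exercise GWASCatalogStudySplitter._resolve_study_id:
+# associations sharing a subStudyDescription keep the original studyId,
+# while each additional distinct description within the same study yields
+# a suffixed identifier (e.g. GCST003436_2).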
+@pytest.mark.parametrize(
+    "observed, expected",
+    [
+        # Test 1 - it shouldn't split
+        (
+            # observed - 2 associations with the same subStudy annotation
+            [
+                (
+                    "varA",
+                    "GCST003436",
+                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
+                ),
+                (
+                    "varB",
+                    "GCST003436",
+                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
+                ),
+            ],
+            # expected - 2 associations with the same unsplit updatedStudyId
+            [
+                ("GCST003436",),
+                ("GCST003436",),
+            ],
+        ),
+        # Test 2 - it should split
+        (
+            # observed - 2 associations with different subStudy annotations
+            [
+                (
+                    "varA",
+                    "GCST003436",
+                    "Endometrial cancer|no_pvalue_text|EFO_1001512",
+                ),
+                (
+                    "varB",
+                    "GCST003436",
+                    "Uterine carcinoma|no_pvalue_text|EFO_0002919",
+                ),
+            ],
+            # expected - the second distinct sub-study gets a split updatedStudyId
+            [
+                ("GCST003436",),
+                ("GCST003436_2",),
+            ],
+        ),
+    ],
+)
+def test__resolve_study_id(
+    spark: SparkSession, observed: list[Any], expected: list[Any]
+) -> None:
+    """Test _resolve_study_id."""
+    observed_df = spark.createDataFrame(
+        observed, schema=["variantId", "studyId", "subStudyDescription"]
+    ).select(
+        # Alias the resolved Column itself so the output column is named
+        # updatedStudyId, matching the expected schema below.
+        GWASCatalogStudySplitter._resolve_study_id(
+            f.col("studyId"), f.col("subStudyDescription")
+        ).alias("updatedStudyId")
+    )
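+    # Compare collected rows against the expected single-column frame.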
+    expected_df = spark.createDataFrame(expected, schema=["updatedStudyId"])
+    assert observed_df.collect() == expected_df.collect()