Skip to content

Commit b60a19f

Browse files
addramirSzymon Szyszkowskiproject-defiant
authored
fix(SummaryStatistics): fix in sanity_filter (#623)
* fix(SummaryStatistics): fix in sanity_filter * fix: adding prune of inf values * fix(dataset): removal of inf values from beta and stderr * fix: fix in test and sanity filter --------- Co-authored-by: Szymon Szyszkowski <ss60@mib117351s.internal.sanger.ac.uk> Co-authored-by: Szymon Szyszkowski <69353402+project-defiant@users.noreply.github.com>
1 parent 48cf2a8 commit b60a19f

File tree

4 files changed

+107
-9
lines changed

4 files changed

+107
-9
lines changed

src/gentropy/dataset/dataset.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
"""Dataset class for gentropy."""
2+
23
from __future__ import annotations
34

45
from abc import ABC, abstractmethod
56
from dataclasses import dataclass
7+
from functools import reduce
68
from typing import TYPE_CHECKING, Any
79

10+
import pyspark.sql.functions as f
11+
from pyspark.sql.types import DoubleType
812
from typing_extensions import Self
913

1014
from gentropy.common.schemas import flatten_schema
@@ -164,6 +168,29 @@ def validate_schema(self: Dataset) -> None:
164168
f"The following fields present differences in their datatypes: {fields_with_different_observed_datatype}."
165169
)
166170

171+
def drop_infinity_values(self: Self, *cols: str) -> Self:
    """Drop rows with infinite values in the requested Double typed columns.

    Infinity type reference - https://spark.apache.org/docs/latest/sql-ref-datatypes.html#floating-point-special-values
    The implementation comes from https://stackoverflow.com/questions/34432998/how-to-replace-infinity-in-pyspark-dataframe

    Args:
        *cols (str): names of the columns to check for infinite values, these should be of DoubleType only!

    Returns:
        Self: Dataset after removing infinite values
    """
    # Nothing requested -> nothing to filter; return the dataset unchanged.
    if len(cols) == 0:
        return self
    # Only two distinct Double infinities exist (+inf and -inf); all string
    # spellings ("Inf", "+Infinity", ...) cast to one of these two values.
    inf_values = [f.lit(float("inf")), f.lit(float("-inf"))]
    conditions = [f.col(c).isin(inf_values) for c in cols]
    # reduce individual filter expressions with or statement
    # to col("beta").isin([...]) | col("standardError").isin([...]) | ...
    condition = reduce(lambda a, b: a | b, conditions)
    self.df = self._df.filter(~condition)
    return self
193+
167194
def persist(self: Self) -> Self:
168195
"""Persist in memory the DataFrame included in the Dataset.
169196

src/gentropy/dataset/summary_statistics.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Summary satistics dataset."""
2+
23
from __future__ import annotations
34

45
from dataclasses import dataclass
@@ -108,9 +109,12 @@ def sanity_filter(self: SummaryStatistics) -> SummaryStatistics:
108109
"""The function filters the summary statistics by sanity filters.
109110
110111
The function filters the summary statistics by the following filters:
111-
- The p-value should not be eqaul 1.
112-
- The beta and se should not be equal 0.
112+
- The p-value should be less than 1.
113+
- The pValueMantissa should be greater than 0.
114+
- The beta should not be equal 0.
113115
- The p-value, beta and se should not be NaN.
116+
- The se should be positive.
117+
- The beta and se should not be infinite.
114118
115119
Returns:
116120
SummaryStatistics: The filtered summary statistics.
@@ -119,13 +123,15 @@ def sanity_filter(self: SummaryStatistics) -> SummaryStatistics:
119123
gwas_df = gwas_df.dropna(
120124
subset=["beta", "standardError", "pValueMantissa", "pValueExponent"]
121125
)
122-
123-
gwas_df = gwas_df.filter((f.col("beta") != 0) & (f.col("standardError") != 0))
126+
gwas_df = gwas_df.filter((f.col("beta") != 0) & (f.col("standardError") > 0))
124127
gwas_df = gwas_df.filter(
125-
f.col("pValueMantissa") * 10 ** f.col("pValueExponent") != 1
128+
(f.col("pValueMantissa") * 10 ** f.col("pValueExponent") < 1)
129+
& (f.col("pValueMantissa") > 0)
126130
)
127-
128-
return SummaryStatistics(
131+
cols = ["beta", "standardError"]
132+
summary_stats = SummaryStatistics(
129133
_df=gwas_df,
130134
_schema=SummaryStatistics.get_schema(),
131-
)
135+
).drop_infinity_values(*cols)
136+
137+
return summary_stats

tests/gentropy/dataset/test_dataset.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@
22

33
from __future__ import annotations
44

5+
import numpy as np
56
import pyspark.sql.functions as f
67
import pytest
78
from gentropy.dataset.dataset import Dataset
89
from gentropy.dataset.study_index import StudyIndex
910
from pyspark.sql import SparkSession
10-
from pyspark.sql.types import IntegerType, StructField, StructType
11+
from pyspark.sql.types import (
12+
DoubleType,
13+
IntegerType,
14+
StructField,
15+
StructType,
16+
)
1117

1218

1319
class MockDataset(Dataset):
@@ -57,3 +63,18 @@ def test_dataset_filter(mock_study_index: StudyIndex) -> None:
5763
filtered.df.select("studyType").distinct().toPandas()["studyType"].to_list()[0]
5864
== expected_filter_value
5965
), "Filtering failed."
66+
67+
68+
def test_dataset_drop_infinity_values() -> None:
    """drop_infinity_values method should remove inf values from the `field` column."""
    spark = SparkSession.getActiveSession()
    # Use np.inf only: the np.Infinity and np.Inf aliases were removed in NumPy 2.0,
    # so referencing them raises AttributeError on modern NumPy.
    data = [np.inf, -np.inf, np.inf, -np.inf, np.inf, -np.inf, 5.1]
    rows = [(v,) for v in data]
    schema = StructType([StructField("field", DoubleType())])
    input_df = spark.createDataFrame(rows, schema=schema)
    assert input_df.count() == 7
    # run without specifying *cols results in no filtering
    ds = MockDataset(_df=input_df, _schema=schema)
    assert ds.drop_infinity_values().df.count() == 7
    # otherwise drop every row carrying an infinite value (only 5.1 survives)
    assert ds.drop_infinity_values("field").df.count() == 1

tests/gentropy/method/test_qc_of_sumstats.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy as np
66
import pandas as pd
77
import pyspark.sql.functions as f
8+
import pytest
9+
from gentropy.common.session import Session
810
from gentropy.dataset.summary_statistics import SummaryStatistics
911
from gentropy.method.sumstat_quality_controls import SummaryStatisticsQC
1012
from pyspark.sql.functions import rand, when
@@ -61,3 +63,45 @@ def test_several_studyid(
6163
)
6264
QC = QC.toPandas()
6365
assert QC.shape == (2, 8)
66+
67+
68+
def test_sanity_filter_remove_inf_values(
    session: Session,
) -> None:
    """Sanity filter removes rows with an infinite value in the standardError field."""
    # Two rows with identical schema; only the first carries an infinite beta
    # (np.inf, not np.Infinity — the latter was removed in NumPy 2.0).
    data = [
        (
            "GCST012234",
            "10_73856419_C_A",
            10,
            73856419,
            np.inf,
            1,
            3.1324,
            -650,
            None,
            0.4671,
        ),
        (
            "GCST012234",
            "14_98074714_G_C",
            14,
            98074714,
            6.697,
            2,
            5.4275,
            -2890,
            None,
            0.4671,
        ),
    ]
    input_df = session.spark.createDataFrame(
        data=data, schema=SummaryStatistics.get_schema()
    )
    summary_stats = SummaryStatistics(
        _df=input_df, _schema=SummaryStatistics.get_schema()
    )
    stats_after_filter = summary_stats.sanity_filter().df.collect()
    # Only the finite row survives the sanity filter.
    assert input_df.count() == 2
    assert len(stats_after_filter) == 1
    assert stats_after_filter[0]["beta"] - 6.697 == pytest.approx(0)

0 commit comments

Comments
 (0)