Skip to content

Commit

Permalink
Merge branch 'dev' into dc_susie_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
project-defiant authored Dec 18, 2024
2 parents 4f8afaa + cb64cc5 commit 1994c67
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 23 deletions.
40 changes: 29 additions & 11 deletions src/gentropy/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from functools import reduce
from typing import TYPE_CHECKING, Any

import pyspark.sql.functions as f
from pyspark.sql.types import DoubleType
from pyspark.sql import DataFrame
from pyspark.sql import functions as f
from pyspark.sql import types as t
from pyspark.sql.window import Window
from typing_extensions import Self

Expand All @@ -18,25 +19,42 @@
if TYPE_CHECKING:
from enum import Enum

from pyspark.sql import Column, DataFrame
from pyspark.sql import Column
from pyspark.sql.types import StructType

from gentropy.common.session import Session


@dataclass
class Dataset(ABC):
"""Open Targets Gentropy Dataset.
"""Open Targets Gentropy Dataset Interface.
`Dataset` is a wrapper around a Spark DataFrame with a predefined schema. Schemas for each child dataset are described in the `schemas` module.
The `Dataset` interface is a wrapper around a Spark DataFrame with a predefined schema.
Class allows for overwriting the schema with `_schema` parameter.
If the `_schema` is not provided, the schema is inferred from the Dataset.get_schema specific
method which must be implemented by the child classes.
"""

_df: DataFrame
_schema: StructType
_schema: StructType | None = None

def __post_init__(self: Dataset) -> None:
"""Post init."""
self.validate_schema()
"""Post init.
Raises:
TypeError: If the type of the _df or _schema is not valid
"""
match self._df:
case DataFrame():
pass
case _:
raise TypeError(f"Invalid type for _df: {type(self._df)}")

match self._schema:
case None | t.StructType():
self.validate_schema()
case _:
raise TypeError(f"Invalid type for _schema: {type(self._schema)}")

@property
def df(self: Dataset) -> DataFrame:
Expand Down Expand Up @@ -64,7 +82,7 @@ def schema(self: Dataset) -> StructType:
Returns:
StructType: Dataframe expected schema
"""
return self._schema
return self._schema or self.get_schema()

@classmethod
def _process_class_params(
Expand Down Expand Up @@ -172,7 +190,7 @@ def validate_schema(self: Dataset) -> None:
Raises:
SchemaValidationError: If the DataFrame schema does not match the expected schema
"""
expected_schema = self._schema
expected_schema = self.schema
observed_schema = self._df.schema

# Unexpected fields in dataset
Expand Down Expand Up @@ -244,7 +262,7 @@ def drop_infinity_values(self: Self, *cols: str) -> Self:
if len(cols) == 0:
return self
inf_strings = ("Inf", "+Inf", "-Inf", "Infinity", "+Infinity", "-Infinity")
inf_values = [f.lit(v).cast(DoubleType()) for v in inf_strings]
inf_values = [f.lit(v).cast(t.DoubleType()) for v in inf_strings]
conditions = [f.col(c).isin(inf_values) for c in cols]
# reduce individual filter expressions with or statement
# to col("beta").isin([lit(Inf)]) | col("beta").isin([lit(Inf)])...
Expand Down
2 changes: 1 addition & 1 deletion src/gentropy/dataset/pairwise_ld.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __post_init__(self: PairwiseLD) -> None:
), f"The number of rows in a pairwise LD table has to be square. Got: {row_count}"

self.dimension = (int(sqrt(row_count)), int(sqrt(row_count)))
self.validate_schema()
super().__post_init__()

@classmethod
def get_schema(cls: type[PairwiseLD]) -> StructType:
Expand Down
35 changes: 24 additions & 11 deletions tests/gentropy/dataset/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,44 @@ def get_schema(cls) -> StructType:
return StructType([StructField("value", IntegerType(), False)])


class TestCoalesceAndRepartition:
class TestDataset:
"""Test TestDataset.coalesce and TestDataset.repartition."""

def test_repartition(self: TestCoalesceAndRepartition) -> None:
def test_repartition(self: TestDataset) -> None:
"""Test Dataset.repartition."""
initial_partitions = self.test_dataset._df.rdd.getNumPartitions()
new_partitions = initial_partitions + 1
self.test_dataset.repartition(new_partitions)
assert self.test_dataset._df.rdd.getNumPartitions() == new_partitions

def test_coalesce(self: TestCoalesceAndRepartition) -> None:
def test_coalesce(self: TestDataset) -> None:
"""Test Dataset.coalesce."""
initial_partitions = self.test_dataset._df.rdd.getNumPartitions()
new_partitions = initial_partitions - 1 if initial_partitions > 1 else 1
self.test_dataset.coalesce(new_partitions)
assert self.test_dataset._df.rdd.getNumPartitions() == new_partitions

def test_initialize_without_schema(self: TestDataset, spark: SparkSession) -> None:
"""Test if Dataset derived class collects the schema from assets if schema is not provided."""
df = spark.createDataFrame([(1,)], schema=MockDataset.get_schema())
ds = MockDataset(_df=df)
assert (
ds.schema == MockDataset.get_schema()
), "Schema should be inferred from df"

def test_passing_incorrect_types(self: TestDataset, spark: SparkSession) -> None:
"""Test if passing incorrect object types to Dataset raises an error."""
with pytest.raises(TypeError):
MockDataset(_df="not a dataframe")
with pytest.raises(TypeError):
MockDataset(_df=self.df, _schema="not a schema")

@pytest.fixture(autouse=True)
def _setup(self: TestCoalesceAndRepartition, spark: SparkSession) -> None:
def _setup(self: TestDataset, spark: SparkSession) -> None:
"""Setup fixture."""
self.test_dataset = MockDataset(
_df=spark.createDataFrame(
[(1,), (2,), (3,)], schema=MockDataset.get_schema()
),
_schema=MockDataset.get_schema(),
)
df = spark.createDataFrame([(1,), (2,), (3,)], schema=MockDataset.get_schema())
self.df = df
self.test_dataset = MockDataset(_df=df, _schema=MockDataset.get_schema())


def test_dataset_filter(mock_study_index: StudyIndex) -> None:
Expand All @@ -68,6 +80,7 @@ def test_dataset_drop_infinity_values() -> None:
rows = [(v,) for v in data]
schema = StructType([StructField("field", DoubleType())])
input_df = spark.createDataFrame(rows, schema=schema)

assert input_df.count() == 7
# run without specifying *cols results in no filtering
ds = MockDataset(_df=input_df, _schema=schema)
Expand All @@ -76,7 +89,7 @@ def test_dataset_drop_infinity_values() -> None:
assert ds.drop_infinity_values("field").df.count() == 1


def test__process_class_params(spark: SparkSession) -> None:
def test_process_class_params(spark: SparkSession) -> None:
"""Test splitting of parameters between class and spark parameters."""
params = {
"_df": spark.createDataFrame([(1,)], schema=MockDataset.get_schema()),
Expand Down

0 comments on commit 1994c67

Please sign in to comment.