diff --git a/src/gentropy/dataset/dataset.py b/src/gentropy/dataset/dataset.py index 67fe05eaf..fa06faec2 100644 --- a/src/gentropy/dataset/dataset.py +++ b/src/gentropy/dataset/dataset.py @@ -8,8 +8,9 @@ from functools import reduce from typing import TYPE_CHECKING, Any -import pyspark.sql.functions as f -from pyspark.sql.types import DoubleType +from pyspark.sql import DataFrame +from pyspark.sql import functions as f +from pyspark.sql import types as t from pyspark.sql.window import Window from typing_extensions import Self @@ -18,7 +19,7 @@ if TYPE_CHECKING: from enum import Enum - from pyspark.sql import Column, DataFrame + from pyspark.sql import Column from pyspark.sql.types import StructType from gentropy.common.session import Session @@ -26,17 +27,34 @@ @dataclass class Dataset(ABC): - """Open Targets Gentropy Dataset. + """Open Targets Gentropy Dataset Interface. - `Dataset` is a wrapper around a Spark DataFrame with a predefined schema. Schemas for each child dataset are described in the `schemas` module. + The `Dataset` interface is a wrapper around a Spark DataFrame with a predefined schema. + Class allows for overwriting the schema with `_schema` parameter. + If the `_schema` is not provided, the schema is inferred from the Dataset.get_schema specific + method which must be implemented by the child classes. """ _df: DataFrame - _schema: StructType + _schema: StructType | None = None def __post_init__(self: Dataset) -> None: - """Post init.""" - self.validate_schema() + """Post init. + + Raises: + TypeError: If the type of the _df or _schema is not valid + """ + match self._df: + case DataFrame(): + pass + case _: + raise TypeError(f"Invalid type for _df: {type(self._df)}") + + match self._schema: + case None | t.StructType(): + self.validate_schema() + case _: + raise TypeError(f"Invalid type for _schema: {type(self._schema)}") @property def df(self: Dataset) -> DataFrame: @@ -64,7 +82,7 @@ def schema(self: Dataset) -> StructType: Returns: StructType: Dataframe expected schema """ - return self._schema + return self._schema or self.get_schema() @classmethod def _process_class_params( @@ -172,7 +190,7 @@ def validate_schema(self: Dataset) -> None: Raises: SchemaValidationError: If the DataFrame schema does not match the expected schema """ - expected_schema = self._schema + expected_schema = self.schema observed_schema = self._df.schema # Unexpected fields in dataset @@ -244,7 +262,7 @@ def drop_infinity_values(self: Self, *cols: str) -> Self: if len(cols) == 0: return self inf_strings = ("Inf", "+Inf", "-Inf", "Infinity", "+Infinity", "-Infinity") - inf_values = [f.lit(v).cast(DoubleType()) for v in inf_strings] + inf_values = [f.lit(v).cast(t.DoubleType()) for v in inf_strings] conditions = [f.col(c).isin(inf_values) for c in cols] # reduce individual filter expressions with or statement # to col("beta").isin([lit(Inf)]) | col("beta").isin([lit(Inf)])... diff --git a/src/gentropy/dataset/pairwise_ld.py b/src/gentropy/dataset/pairwise_ld.py index b64592094..ab68a74ab 100644 --- a/src/gentropy/dataset/pairwise_ld.py +++ b/src/gentropy/dataset/pairwise_ld.py @@ -38,7 +38,7 @@ def __post_init__(self: PairwiseLD) -> None: ), f"The number of rows in a pairwise LD table has to be square. Got: {row_count}" self.dimension = (int(sqrt(row_count)), int(sqrt(row_count))) - self.validate_schema() + super().__post_init__() @classmethod def get_schema(cls: type[PairwiseLD]) -> StructType: diff --git a/tests/gentropy/dataset/test_dataset.py b/tests/gentropy/dataset/test_dataset.py index 7c61f3f52..96a96ec27 100644 --- a/tests/gentropy/dataset/test_dataset.py +++ b/tests/gentropy/dataset/test_dataset.py @@ -21,32 +21,44 @@ def get_schema(cls) -> StructType: return StructType([StructField("value", IntegerType(), False)]) -class TestCoalesceAndRepartition: +class TestDataset: """Test TestDataset.coalesce and TestDataset.repartition.""" - def test_repartition(self: TestCoalesceAndRepartition) -> None: + def test_repartition(self: TestDataset) -> None: """Test Dataset.repartition.""" initial_partitions = self.test_dataset._df.rdd.getNumPartitions() new_partitions = initial_partitions + 1 self.test_dataset.repartition(new_partitions) assert self.test_dataset._df.rdd.getNumPartitions() == new_partitions - def test_coalesce(self: TestCoalesceAndRepartition) -> None: + def test_coalesce(self: TestDataset) -> None: """Test Dataset.coalesce.""" initial_partitions = self.test_dataset._df.rdd.getNumPartitions() new_partitions = initial_partitions - 1 if initial_partitions > 1 else 1 self.test_dataset.coalesce(new_partitions) assert self.test_dataset._df.rdd.getNumPartitions() == new_partitions + def test_initialize_without_schema(self: TestDataset, spark: SparkSession) -> None: + """Test if Dataset derived class collects the schema from assets if schema is not provided.""" + df = spark.createDataFrame([(1,)], schema=MockDataset.get_schema()) + ds = MockDataset(_df=df) + assert ( + ds.schema == MockDataset.get_schema() + ), "Schema should be inferred from df" + + def test_passing_incorrect_types(self: TestDataset, spark: SparkSession) -> None: + """Test if passing incorrect object types to Dataset raises an error.""" + with pytest.raises(TypeError): + MockDataset(_df="not a dataframe") + with pytest.raises(TypeError): + MockDataset(_df=self.df, _schema="not a schema") + @pytest.fixture(autouse=True) - def _setup(self: TestCoalesceAndRepartition, spark: SparkSession) -> None: + def _setup(self: TestDataset, spark: SparkSession) -> None: """Setup fixture.""" - self.test_dataset = MockDataset( - _df=spark.createDataFrame( - [(1,), (2,), (3,)], schema=MockDataset.get_schema() - ), - _schema=MockDataset.get_schema(), - ) + df = spark.createDataFrame([(1,), (2,), (3,)], schema=MockDataset.get_schema()) + self.df = df + self.test_dataset = MockDataset(_df=df, _schema=MockDataset.get_schema()) def test_dataset_filter(mock_study_index: StudyIndex) -> None: @@ -68,6 +80,7 @@ def test_dataset_drop_infinity_values() -> None: rows = [(v,) for v in data] schema = StructType([StructField("field", DoubleType())]) input_df = spark.createDataFrame(rows, schema=schema) + assert input_df.count() == 7 # run without specifying *cols results in no filtering ds = MockDataset(_df=input_df, _schema=schema) @@ -76,7 +89,7 @@ def test_dataset_drop_infinity_values() -> None: assert ds.drop_infinity_values("field").df.count() == 1 -def test__process_class_params(spark: SparkSession) -> None: +def test_process_class_params(spark: SparkSession) -> None: """Test splitting of parameters between class and spark parameters.""" params = { "_df": spark.createDataFrame([(1,)], schema=MockDataset.get_schema()),