Add RandomDF, a generator for small random Spark DataFrames, and expose it
to tests as SparkleTestCase.randomDF(). Bump the version to 1.1.0.dev.

Columns are passed as "name" (integer column) or "name:type", where type is
a Spark simple type name such as "bigint" or "string". All columns of one
DataFrame share a single random row count, and values are generated to match
the declared column type so that Spark's schema verification accepts them.

NOTE(review): this patch was recovered from a copy whose line breaks were
lost. The wrapping of the unchanged context lines is reconstructed and should
be confirmed against the target files before applying. The "index" header
lines of the two Python files are omitted because the recorded blob ids do
not describe the content below.

diff --git a/build.py b/build.py
index db96a0f..c43c009 100644
--- a/build.py
+++ b/build.py
@@ -13,7 +13,7 @@ don't leave any files in your workspace.
 There is one convenience method for asserting dataframe equality.
 """
 default_task = ["clean", "analyze", "publish"]
-version = "1.0.1"
+version = "1.1.0.dev"
 url = "https://github.com/machielg/sparkle-test/"
 license = "GPLv3+"
 
diff --git a/src/main/python/sparkle_test/test_case.py b/src/main/python/sparkle_test/test_case.py
--- a/src/main/python/sparkle_test/test_case.py
+++ b/src/main/python/sparkle_test/test_case.py
@@ -1,15 +1,106 @@
 import os
 import shutil
+import string
 import tempfile
 import unittest
 import warnings
 from abc import ABC
 from datetime import datetime, date
+from inspect import signature
+from random import choice, randint, uniform
 
+import pyspark
 from pandas.util.testing import assert_frame_equal
 # noinspection PyProtectedMember
 from pyspark import SQLContext
-from pyspark.sql import SparkSession
+from pyspark.sql import SparkSession, DataFrame
+from pyspark.sql.types import (BooleanType, DataType, DoubleType, FloatType, IntegerType,
+                               IntegralType, StringType, StructField, StructType)
+
+
+def types_dict():
+    """Map Spark simple type names (e.g. ``"bigint"``) to their ``DataType`` classes.
+
+    Only classes that do not define their own ``__init__`` are included, so
+    every class in the mapping can be instantiated without arguments.
+    """
+    type_classes = [cls for cls in vars(pyspark.sql.types).values() if
+                    isinstance(cls, type) and
+                    issubclass(cls, DataType) and
+                    signature(object.__init__) == signature(cls.__init__)]
+
+    return {c().simpleString(): c for c in type_classes}
+
+
+class RandomDF:
+    """Generate a small Spark DataFrame filled with random, nullable values.
+
+    Each column is given as ``"name"`` (an integer column) or ``"name:type"``,
+    where ``type`` is a Spark simple type name such as ``bigint`` or ``string``.
+    Without column arguments one integer column named after a random ASCII
+    letter is generated.
+    """
+    types = types_dict()
+
+    def __init__(self, spark: SparkSession, *col_name: str):
+        self.spark = spark
+        if not col_name:
+            self.__cols = [choice(string.ascii_letters)]
+        else:
+            self.__cols = col_name
+
+    def generate(self) -> DataFrame:
+        """Create the DataFrame; every call draws fresh random rows."""
+        return self.spark.createDataFrame(self._vals, self._schema)
+
+    @property
+    def _vals(self):
+        # Draw the row count once. ``_length`` is random on every access, so
+        # reading it per column would give columns of different lengths that
+        # zip() silently truncates to the shortest one.
+        length = self._length
+        columns = [[self.rand_val(self._to_type(c)) for _ in range(length)]
+                   for c in self.__cols]
+        return list(zip(*columns))
+
+    @property
+    def _cols(self):
+        return [c.split(":")[0] for c in self.__cols]
+
+    @property
+    def _schema(self):
+        return StructType([StructField(c.split(":")[0], self._to_type(c)) for c in self.__cols])
+
+    @staticmethod
+    def rand_val(data_type: DataType = None):
+        """Return ``None`` or a random value that is valid for ``data_type``.
+
+        Without ``data_type`` an integer is generated. Types without a value
+        generator always yield ``None``, which any nullable column accepts.
+        """
+        if data_type is None or isinstance(data_type, IntegralType):
+            value = randint(-1, 10)
+        elif isinstance(data_type, (FloatType, DoubleType)):
+            value = uniform(-1, 10)
+        elif isinstance(data_type, StringType):
+            value = choice(string.ascii_letters)
+        elif isinstance(data_type, BooleanType):
+            value = choice([True, False])
+        else:
+            value = None
+        return choice([None, value])
+
+    @property
+    def _length(self):
+        return randint(1, 10)
+
+    def _to_type(self, col_and_type: str):
+        """Return the Spark type instance for a ``"name"`` or ``"name:type"`` column."""
+        if ":" in col_and_type:
+            tipe = col_and_type.split(":")[1]
+            return self.types[tipe]()
+        else:
+            return IntegerType()
 
 
 class SparkleTestCase(unittest.TestCase, ABC):
@@ -143,3 +234,7 @@ def dt(date_time_str: str) -> datetime:
         :return: datetime
         """
         return datetime.strptime(date_time_str, '%Y-%m-%d %H:%M:%S')
+
+    def randomDF(self, *col_name: str) -> DataFrame:
+        """Return a DataFrame of random rows; columns are ``"name"`` or ``"name:type"``."""
+        return RandomDF(self.spark, *col_name).generate()
diff --git a/src/unittest/python/test_case_tests.py b/src/unittest/python/test_case_tests.py
--- a/src/unittest/python/test_case_tests.py
+++ b/src/unittest/python/test_case_tests.py
@@ -1,5 +1,7 @@
 import pathlib
 
+from pyspark.sql.types import LongType, StringType
+
 from sparkle_test import SparkleTestCase
 
 
@@ -27,3 +29,29 @@ def _find_log_files_outside_target_dir(self) -> list:
         log_files = [f for f in pathlib.Path(path).rglob('*.log') if
                      target_dir not in f.parents]
         return log_files
+
+    def test_random_df(self):
+        df = self.randomDF()
+        self.assertIsNotNone(df)
+        self.assertGreaterEqual(df.count(), 1)
+        self.assertGreaterEqual(len(df.columns), 1)
+
+    def test_random_df_with_cols(self):
+        df = self.randomDF("a", "b")
+        self.assertEqual(2, len(df.columns))
+        self.assertEqual("a", df.columns[0])
+        self.assertEqual("b", df.columns[1])
+        self.assertGreaterEqual(df.count(), 1)
+
+    def test_random_df_with_cols_and_type(self):
+        df = self.randomDF("a:bigint")
+        self.assertEqual(1, len(df.columns))
+        self.assertEqual("a", df.columns[0])
+        self.assertEqual(LongType(), df.schema.fields[0].dataType)
+        self.assertGreaterEqual(df.count(), 1)
+
+    def test_random_df_with_non_integer_types(self):
+        df = self.randomDF("a:string", "b:double")
+        self.assertEqual(["a", "b"], df.columns)
+        self.assertEqual(StringType(), df.schema.fields[0].dataType)
+        self.assertGreaterEqual(df.count(), 1)