Skip to content

Commit

Permalink
add randomDF to SparkleTestCase
Browse files Browse the repository at this point in the history
  • Loading branch information
machielg committed Nov 18, 2019
1 parent 14f28f6 commit 7dcda24
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
2 changes: 1 addition & 1 deletion build.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
don't leave any files in your workspace. There is one convenience method for asserting dataframe equality.
"""
default_task = ["clean", "analyze", "publish"]
version = "1.0.1"
version = "1.1.0.dev"

url = "https://github.com/machielg/sparkle-test/"
license = "GPLv3+"
Expand Down
64 changes: 63 additions & 1 deletion src/main/python/sparkle_test/test_case.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,74 @@
import os
import shutil
import string
import tempfile
import unittest
import warnings
from abc import ABC
from datetime import datetime, date
from inspect import signature
from random import choice, randint

import pyspark
from pandas.util.testing import assert_frame_equal
# noinspection PyProtectedMember
from pyspark import SQLContext
from pyspark.sql import SparkSession
from pyspark.sql import SparkSession, DataFrame
from pyspark.sql.types import StructType, StructField, DataType, IntegerType


def types_dict():
    """Map Spark SQL simple type names to their no-argument DataType classes.

    Scans ``pyspark.sql.types`` for every ``DataType`` subclass whose
    ``__init__`` signature matches ``object.__init__`` — i.e. types that can
    be instantiated with no arguments (so parameterized types like
    ``DecimalType`` with custom precision are deliberately excluded only if
    their signature differs).

    :return: dict of simple type name (e.g. ``'int'``, ``'bigint'``) to the
             corresponding DataType class
    """
    no_arg_signature = signature(object.__init__)
    type_classes = [cls for cls in pyspark.sql.types.__dict__.values()
                    if isinstance(cls, type)
                    and issubclass(cls, DataType)
                    and signature(cls.__init__) == no_arg_signature]
    # dict comprehension instead of dict([...]) (ruff C402)
    return {cls().simpleString(): cls for cls in type_classes}


class RandomDF:
    """Builds a DataFrame of random length filled with random, nullable ints.

    Columns are described by "name" or "name:type" spec strings; a spec
    without an explicit type defaults to integer. With no specs at all, a
    single column with a random one-letter name is generated.
    """

    # simple type name (e.g. "bigint") -> no-arg DataType class
    types = types_dict()

    def __init__(self, spark: SparkSession, *col_name: str):
        """
        :param spark: session used to create the DataFrame
        :param col_name: zero or more "name" or "name:type" column specs
        """
        self.spark = spark
        if col_name:
            self.__cols = col_name
        else:
            # no spec given: one column with a random single-letter name
            self.__cols = [choice(string.ascii_letters)]

    def generate(self) -> DataFrame:
        """Create and return the random DataFrame."""
        return self.spark.createDataFrame(self._vals, self._schema)

    @property
    def _vals(self):
        # Evaluate the random row count ONCE: the previous version read
        # self._length (a fresh randint per access) inside the per-column
        # comprehension, so columns got different lengths and zip silently
        # truncated the frame to the shortest column.
        n_rows = self._length
        columns = [[self.rand_val() for _ in range(n_rows)]
                   for _ in self._cols]
        # transpose column-major random values into row tuples
        return zip(*columns)

    @property
    def _cols(self):
        # column names with any ":type" suffix stripped
        return [c.split(":")[0] for c in self.__cols]

    @property
    def _schema(self):
        # nullable fields (StructField default), typed per spec
        return StructType([StructField(c.split(":")[0], self._to_type(c))
                           for c in self.__cols])

    @staticmethod
    def rand_val():
        """Return either None or a random int in [-1, 10]."""
        return choice([None, randint(-1, 10)])

    @property
    def _length(self):
        # random row count; at least one row so the frame is never empty
        # (fixed local-name typo: df_lenght -> inline return)
        return randint(1, 10)

    def _to_type(self, col_and_type: str) -> DataType:
        """Resolve the DataType for a "name" or "name:type" spec.

        :param col_and_type: column spec string
        :return: instantiated DataType; IntegerType when no type is given
        :raises KeyError: if the type name is not a known no-arg Spark type
        """
        if ":" in col_and_type:
            type_name = col_and_type.split(":")[1]
            return self.types[type_name]()
        return IntegerType()


class SparkleTestCase(unittest.TestCase, ABC):
Expand Down Expand Up @@ -143,3 +202,6 @@ def dt(date_time_str: str) -> datetime:
:return: datetime
"""
return datetime.strptime(date_time_str, '%Y-%m-%d %H:%M:%S')

def randomDF(self, *col_name: str) -> DataFrame:
    """Create a small DataFrame filled with random, nullable integer data.

    :param col_name: optional "name" or "name:type" column specs; with no
                     arguments a single random-letter integer column is used
    :return: a randomly sized, randomly populated DataFrame
    """
    builder = RandomDF(self.spark, *col_name)
    return builder.generate()
22 changes: 22 additions & 0 deletions src/unittest/python/test_case_tests.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import pathlib

from pyspark.sql.types import LongType

from sparkle_test import SparkleTestCase


Expand Down Expand Up @@ -27,3 +29,23 @@ def _find_log_files_outside_target_dir(self) -> list:

log_files = [f for f in pathlib.Path(path).rglob('*.log') if target_dir not in f.parents]
return log_files

def test_random_df(self):
    """With no arguments, randomDF yields a non-empty frame with at least one column."""
    frame = self.randomDF()
    self.assertIsNotNone(frame)
    self.assertGreaterEqual(len(frame.columns), 1)
    self.assertGreaterEqual(frame.count(), 1)

def test_random_df_with_cols(self):
    """Column names are taken from the given specs, in order."""
    frame = self.randomDF("a", "b")
    # single list comparison covers both the count and the order of columns
    self.assertEqual(["a", "b"], frame.columns)
    self.assertGreaterEqual(frame.count(), 1)

def test_random_df_with_cols_and_type(self):
    """A ':bigint' suffix in the spec maps the column to Spark's LongType."""
    frame = self.randomDF("a:bigint")
    self.assertEqual(["a"], frame.columns)
    first_field = frame.schema.fields[0]
    self.assertEqual(LongType(), first_field.dataType)
    self.assertGreaterEqual(frame.count(), 1)

0 comments on commit 7dcda24

Please sign in to comment.