arrays_in_udfs.py

import time
from datetime import datetime, date
from typing import List, Union
from pyspark.sql import DataFrame, SparkSession
import pyspark.sql.functions as spark_func
from pyspark.sql.types import IntegerType, StructType, StructField, ArrayType, StringType, DateType
from common import create_spark_session
from random import Random


def generate_dates_array_data(spark: SparkSession) -> DataFrame:
    """Return a small DataFrame with an integer id and an array of 'yyyy-MM-dd' date strings."""
    data = [
        (1, ["2022-01-02", "2022-01-08", "2022-01-04", "2022-01-07"]),
        (2, ["2022-01-03", "2022-01-01", "2022-01-02"]),
        (3, ["2022-01-10", "2022-01-12", "2022-01-03", "2022-01-15", "2022-01-01"]),
        (4, ["2022-01-22", "2022-01-21", "2022-01-10", "2022-01-14", "2022-01-24", "2022-01-15", "2022-01-06"]),
        (5, ["2022-01-03"]),
    ]
    data_schema = StructType([
        StructField("id", IntegerType()),
        StructField("dates", ArrayType(StringType())),
    ])
    data_df = spark.createDataFrame(data=data, schema=data_schema)
    return data_df


def generate_minimum_from_example_data(spark: SparkSession) -> DataFrame:
    """Return a DataFrame with an id, a `minimum_after` cut-off date and an array of date strings."""
    data = [
        (1, "2022-01-03", ["2022-01-02", "2022-01-08", "2022-01-04", "2022-01-07"]),
        (2, "2022-01-05", ["2022-01-03", "2022-01-01", "2022-01-02"]),
        (3, "2022-01-01", ["2022-01-10", "2022-01-12", "2022-01-03", "2022-01-15", "2022-01-01"]),
        (4, "2022-01-11", ["2022-01-22", "2022-01-21", "2022-01-10", "2022-01-14", "2022-01-24", "2022-01-15", "2022-01-06"]),
        (5, "2022-01-03", ["2022-01-03"]),
    ]
    data_schema = StructType([
        StructField("id", IntegerType()),
        StructField("minimum_after", StringType()),
        StructField("dates", ArrayType(StringType())),
    ])
    data_df = spark\
        .createDataFrame(data=data, schema=data_schema)\
        .withColumn("minimum_after", spark_func.to_date("minimum_after", "yyyy-MM-dd"))
    return data_df


@spark_func.udf(returnType=DateType())
def minimum_date(date_strs: List[str]) -> date:
    """Return the earliest date in an array of 'yyyy-MM-dd' strings."""
    return min([datetime.strptime(dt_str, "%Y-%m-%d") for dt_str in date_strs]).date()


@spark_func.udf(returnType=DateType())
def minimum_date_after(date_strs: List[str], minimum_after: date) -> Union[date, None]:
    """Return the earliest date in `date_strs` on or after `minimum_after`, or None if none qualifies."""
    parsed_dates = [datetime.strptime(dt_str, "%Y-%m-%d").date() for dt_str in date_strs]
    filtered_dates = [dt for dt in parsed_dates if dt >= minimum_after]
    if filtered_dates:
        return min(filtered_dates)
    return None
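

# The benchmark at the bottom reads "test_data/*.parquet", but the script does not include the
# code that produced those files. The helper below is only a hedged sketch of one way to generate
# compatible data (an id plus an array of 'yyyy-MM-dd' strings); the row count, seed and date
# range are illustrative assumptions, and the function is not called from __main__.
def generate_test_parquet_data(spark: SparkSession, num_rows: int = 100_000, path: str = "test_data") -> None:
    """Sketch: write parquet files with the (id, dates) schema the examples above expect."""
    from datetime import timedelta  # local import keeps the sketch self-contained

    rng = Random(42)
    rows = [
        (
            row_id,
            [(date(2022, 1, 1) + timedelta(days=rng.randint(0, 364))).isoformat()
             for _ in range(rng.randint(1, 10))],
        )
        for row_id in range(num_rows)
    ]
    schema = StructType([
        StructField("id", IntegerType()),
        StructField("dates", ArrayType(StringType())),
    ])
    spark.createDataFrame(rows, schema=schema).write.mode("overwrite").parquet(path)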


if __name__ == '__main__':
    spark = create_spark_session(app_name="Arrays in UDF")

    # Example 1 - minimum date in an array of dates, using a Python UDF
    test_df = generate_dates_array_data(spark=spark)
    result_df = test_df.withColumn("minimum_date", minimum_date(spark_func.col("dates")))
    result_df.show(truncate=False)
    result_df.printSchema()

    # Example 1 using native Spark functions
    result_df_alt = test_df.withColumn(
        "minimum_date",
        spark_func.array_min(spark_func.transform("dates", lambda dt_str: spark_func.to_date(dt_str, "yyyy-MM-dd"))))
    result_df_alt.show(truncate=False)
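    # Not in the original script: confirm the native-functions version also yields a DateType column.
    result_df_alt.printSchema()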

    # Example 2 - minimum date after a given date in an array of dates, using a Python UDF
    test_df = generate_minimum_from_example_data(spark=spark)
    result_df = test_df.withColumn(
        "minimum_date", minimum_date_after(spark_func.col("dates"), spark_func.col("minimum_after"))
    )
    result_df.show(truncate=False)
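
    # Example 2 using native Spark functions (a sketch, not in the original script): drop the parsed
    # dates that fall before `minimum_after` with a higher-order filter, then take the minimum.
    # Assumes Spark >= 3.1 for spark_func.transform / spark_func.filter.
    result_df_alt = test_df.withColumn(
        "minimum_date",
        spark_func.array_min(
            spark_func.filter(
                spark_func.transform("dates", lambda dt_str: spark_func.to_date(dt_str, "yyyy-MM-dd")),
                lambda dt: dt >= spark_func.col("minimum_after"),
            )
        ),
    )
    result_df_alt.show(truncate=False)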

    # Performance: Python UDF vs. native Spark SQL functions
    start = time.time()
    test_df = spark.read.parquet("test_data/*.parquet")
    print(f"Data contains {test_df.count()} rows")

    # Comment out one variant and run the other; compare timings over multiple runs.

    # Using a Python UDF
    min_dates_df = test_df\
        .withColumn("minimum_date", minimum_date(spark_func.col("dates")))\
        .groupby("minimum_date")\
        .agg(spark_func.count(spark_func.lit(1)).alias("num_rows"))
    min_dates_df.show(truncate=False, n=1000)

    # Using native Spark SQL functions
    min_dates_df = test_df\
        .withColumn(
            "minimum_date",
            spark_func.array_min(spark_func.transform("dates", lambda dt_str: spark_func.to_date(dt_str, "yyyy-MM-dd"))))\
        .groupby("minimum_date")\
        .agg(spark_func.count(spark_func.lit(1)).alias("num_rows"))
    min_dates_df.show(truncate=False, n=1000)

    print(f"Took {time.time() - start} seconds.")