
Commit 705a3a2

Kimahriman authored and dongjoon-hyun committed
[SPARK-54153][PYTHON] Support profiling iterator based Python UDFs
### What changes were proposed in this pull request?

Updates the v2 Spark-session based Python UDF profiler to support profiling iterator based UDFs.

```python
from collections.abc import Iterator
from pstats import SortKey

import pyarrow as pa

df = spark.range(100000)

def map_func(iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
    for batch in iter:
        yield pa.RecordBatch.from_arrays([pa.compute.add(batch.column("id"), 10)], ["id"])

spark.conf.set('spark.sql.pyspark.udf.profiler', 'perf')
df.mapInArrow(map_func, df.schema).collect()

spark.conf.set('spark.sql.pyspark.udf.profiler', 'memory')
df.mapInArrow(map_func, df.schema).collect()

for stats in spark.profile.profiler_collector._perf_profile_results.values():
    stats.sort_stats(SortKey.CUMULATIVE).print_stats(20)

spark.profile.show(type="memory")
```

```
         1395288 function calls (1359888 primitive calls) in 2.850 seconds

   Ordered by: cumulative time
   List reduced from 1546 to 20 due to restriction <20>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
      416    0.008    0.000    5.901    0.014 __init__.py:1(<module>)
   424/24    0.000    0.000    2.850    0.119 {built-in method builtins.next}
       24    0.001    0.000    2.850    0.119 test.py:11(map_func)
       16    0.002    0.000    2.646    0.165 compute.py:244(wrapper)
  2752/24    0.016    0.000    2.642    0.110 <frozen importlib._bootstrap>:1349(_find_and_load)
  2752/24    0.013    0.000    2.641    0.110 <frozen importlib._bootstrap>:1304(_find_and_load_unlocked)
       64    0.002    0.000    2.618    0.041 api.py:1(<module>)
  2704/24    0.009    0.000    2.612    0.109 <frozen importlib._bootstrap>:911(_load_unlocked)
  2264/24    0.005    0.000    2.611    0.109 <frozen importlib._bootstrap_external>:993(exec_module)
  6336/48    0.004    0.000    2.591    0.054 <frozen importlib._bootstrap>:480(_call_with_frames_removed)
  2400/24    0.023    0.000    2.591    0.108 {built-in method builtins.exec}
       24    0.002    0.000    1.927    0.080 generic.py:1(<module>)
  520/320    0.002    0.000    1.429    0.004 {built-in method builtins.__import__}
4312/2896    0.006    0.000    1.190    0.000 <frozen importlib._bootstrap>:1390(_handle_fromlist)
        8    0.001    0.000    1.014    0.127 frame.py:1(<module>)
4392/4312    0.069    0.000    0.953    0.000 {built-in method builtins.__build_class__}
     2264    0.030    0.000    0.562    0.000 <frozen importlib._bootstrap_external>:1066(get_code)
       16    0.001    0.000    0.551    0.034 indexing.py:1(<module>)
       24    0.001    0.000    0.451    0.019 datetimes.py:1(<module>)
       16    0.001    0.000    0.401    0.025 datetimelike.py:1(<module>)

============================================================
Profile of UDF<id=3>
============================================================
Filename: /data/projects/spark/python/test.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    11   1212.1 MiB   1212.1 MiB           8   def map_func(iter: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
    12   1212.1 MiB     -0.2 MiB          24       for batch in iter:
    13   1212.1 MiB     -0.2 MiB          32           yield pa.RecordBatch.from_arrays([pa.compute.add(batch.column("id"), 10)], ["id"])
```

### Why are the changes needed?

To add valuable profiling support to all types of UDFs.

### Does this PR introduce _any_ user-facing change?

Yes, iterator based Python UDFs can now be profiled with the SQL config based profiler.

### How was this patch tested?

Updated UTs that were specifically testing that this wasn't supported to show they are now supported.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #52853 from Kimahriman/udf-iter-profiler.
Authored-by: Adam Binford <adamq43@gmail.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
(cherry picked from commit 6a369f9)
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
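For reference, the collected results can also be read back through the public session profiler API instead of the private `_perf_profile_results` attribute used in the description. A minimal sketch, assuming the `mapInArrow` example above has already been run in an active `spark` session:

```python
# Minimal sketch; assumes the mapInArrow example above has already been executed
# with the perf and memory profilers enabled.
spark.profile.show(type="perf")    # print cProfile stats for each profiled UDF
spark.profile.show(type="memory")  # print memory_profiler line stats
spark.profile.clear()              # drop accumulated results before the next run
```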
1 parent 3ccc769 commit 705a3a2

File tree

6 files changed: +157 -64 lines changed

python/pyspark/sql/tests/test_udf_profiler.py

Lines changed: 63 additions & 25 deletions

```diff
@@ -28,6 +28,7 @@
 from pyspark import SparkConf
 from pyspark.errors import PySparkValueError
 from pyspark.sql import SparkSession
+from pyspark.sql.datasource import DataSource, DataSourceReader
 from pyspark.sql.functions import col, arrow_udf, pandas_udf, udf
 from pyspark.sql.window import Window
 from pyspark.profiler import UDFBasicProfiler
@@ -325,59 +326,47 @@ def add2(x):
         not have_pandas or not have_pyarrow,
         cast(str, pandas_requirement_message or pyarrow_requirement_message),
     )
-    def test_perf_profiler_pandas_udf_iterator_not_supported(self):
+    def test_perf_profiler_pandas_udf_iterator(self):
         import pandas as pd
 
         @pandas_udf("long")
-        def add1(x):
-            return x + 1
-
-        @pandas_udf("long")
-        def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
+        def add(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
             for s in iter:
-                yield s + 2
+                yield s + 1
 
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
-            df = self.spark.range(10, numPartitions=2).select(
-                add1("id"), add2("id"), add1("id"), add2(col("id") + 1)
-            )
+            df = self.spark.range(10, numPartitions=2).select(add("id"))
             df.collect()
 
         self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
+            self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=4)
 
     @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
-    def test_perf_profiler_arrow_udf_iterator_not_supported(self):
+    def test_perf_profiler_arrow_udf_iterator(self):
         import pyarrow as pa
 
         @arrow_udf("long")
-        def add1(x):
-            return pa.compute.add(x, 1)
-
-        @arrow_udf("long")
-        def add2(iter: Iterator[pa.Array]) -> Iterator[pa.Array]:
+        def add(iter: Iterator[pa.Array]) -> Iterator[pa.Array]:
             for s in iter:
-                yield pa.compute.add(s, 2)
+                yield pa.compute.add(s, 1)
 
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
-            df = self.spark.range(10, numPartitions=2).select(
-                add1("id"), add2("id"), add1("id"), add2(col("id") + 1)
-            )
+            df = self.spark.range(10, numPartitions=2).select(add("id"))
             df.collect()
 
         self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
 
         for id in self.profile_results:
-            self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
+            self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=4)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
         cast(str, pandas_requirement_message or pyarrow_requirement_message),
     )
-    def test_perf_profiler_map_in_pandas_not_supported(self):
-        df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
+    def test_perf_profiler_map_in_pandas(self):
+        df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")).repartition(1)
 
         def filter_func(iterator):
             for pdf in iterator:
@@ -386,7 +375,28 @@ def filter_func(iterator):
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
             df.mapInPandas(filter_func, df.schema).show()
 
-        self.assertEqual(0, len(self.profile_results), str(self.profile_results.keys()))
+        self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
+
+        for id in self.profile_results:
+            self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
+
+    @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
+    def test_perf_profiler_map_in_arrow(self):
+        import pyarrow as pa
+
+        df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age")).repartition(1)
+
+        def map_func(iterator: Iterator[pa.RecordBatch]) -> Iterator[pa.RecordBatch]:
+            for batch in iterator:
+                yield pa.RecordBatch.from_arrays(
+                    [batch.column("id"), pa.compute.add(batch.column("age"), 1)], ["id", "age"]
+                )
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            df.mapInArrow(map_func, df.schema).show()
+
+        for id in self.profile_results:
+            self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
@@ -575,6 +585,34 @@ def summarize(left, right):
         for id in self.profile_results:
             self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=2)
 
+    def test_perf_profiler_data_source(self):
+        class TestDataSourceReader(DataSourceReader):
+            def __init__(self, schema):
+                self.schema = schema
+
+            def partitions(self):
+                raise NotImplementedError
+
+            def read(self, partition):
+                yield from ((1,), (2,), (3,))
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return "id long"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader(schema)
+
+        self.spark.dataSource.register(TestDataSource)
+
+        with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
+            self.spark.read.format("TestDataSource").load().collect()
+
+        self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
+
+        for id in self.profile_results:
+            self.assert_udf_profile_present(udf_id=id, expected_line_count_prefix=4)
+
     def test_perf_profiler_render(self):
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "perf"}):
             _do_computation(self.spark)
```
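For context, the renamed tests above exercise a pattern like the following. This is a minimal sketch, not test code: it assumes an active SparkSession named `spark` with pandas and pyarrow installed; with this change the iterator-based UDF now produces a profile entry instead of being skipped.

```python
from typing import Iterator

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("long")
def add(it: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for s in it:
        yield s + 1  # profiled per batch now that iterator UDFs are supported

spark.conf.set("spark.sql.pyspark.udf.profiler", "perf")
spark.range(10, numPartitions=2).select(add("id")).collect()
spark.profile.show(type="perf")
```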

python/pyspark/tests/test_memory_profiler.py

Lines changed: 10 additions & 6 deletions

```diff
@@ -341,12 +341,13 @@ def add2(x):
         not have_pandas or not have_pyarrow,
         cast(str, pandas_requirement_message or pyarrow_requirement_message),
     )
-    def test_memory_profiler_pandas_udf_iterator_not_supported(self):
+    def test_memory_profiler_pandas_udf_iterator(self):
         import pandas as pd
 
         @pandas_udf("long")
-        def add1(x):
-            return x + 1
+        def add1(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
+            for s in iter:
+                yield s + 1
 
         @pandas_udf("long")
         def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
@@ -359,7 +360,7 @@ def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
             )
             df.collect()
 
-        self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
+        self.assertEqual(3, len(self.profile_results), str(self.profile_results.keys()))
 
         for id in self.profile_results:
             self.assert_udf_memory_profile_present(udf_id=id)
@@ -368,7 +369,7 @@ def add2(iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
         not have_pandas or not have_pyarrow,
         cast(str, pandas_requirement_message or pyarrow_requirement_message),
     )
-    def test_memory_profiler_map_in_pandas_not_supported(self):
+    def test_memory_profiler_map_in_pandas(self):
         df = self.spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
 
         def filter_func(iterator):
@@ -378,7 +379,10 @@ def filter_func(iterator):
         with self.sql_conf({"spark.sql.pyspark.udf.profiler": "memory"}):
             df.mapInPandas(filter_func, df.schema).show()
 
-        self.assertEqual(0, len(self.profile_results), str(self.profile_results.keys()))
+        self.assertEqual(1, len(self.profile_results), str(self.profile_results.keys()))
+
+        for id in self.profile_results:
+            self.assert_udf_memory_profile_present(udf_id=id)
 
     @unittest.skipIf(
         not have_pandas or not have_pyarrow,
```
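Similarly, memory profiling now covers `mapInPandas`. A minimal sketch of the same scenario outside the test harness; it assumes an active `spark` session with memory-profiler installed on the workers, and the filter body here is illustrative rather than copied from the test:

```python
import pandas as pd

df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))

def filter_func(iterator):
    for pdf in iterator:
        yield pdf[pdf.id == 1]  # illustrative filter; any per-batch transform works

spark.conf.set("spark.sql.pyspark.udf.profiler", "memory")
df.mapInPandas(filter_func, df.schema).show()
spark.profile.show(type="memory")
```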

python/pyspark/worker.py

Lines changed: 77 additions & 29 deletions

```diff
@@ -1158,17 +1158,19 @@ def func(*args):
     return f, args_offsets
 
 
-def _supports_profiler(eval_type: int) -> bool:
-    return eval_type not in (
+def _is_iter_based(eval_type: int) -> bool:
+    return eval_type in (
         PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
         PythonEvalType.SQL_SCALAR_ARROW_ITER_UDF,
         PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
         PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
         PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
+        PythonEvalType.SQL_GROUPED_MAP_ARROW_ITER_UDF,
+        PythonEvalType.SQL_GROUPED_MAP_PANDAS_ITER_UDF,
     )
 
 
-def wrap_perf_profiler(f, result_id):
+def wrap_perf_profiler(f, eval_type, result_id):
     import cProfile
     import pstats
 
@@ -1178,38 +1180,89 @@ def wrap_perf_profiler(f, result_id):
         SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
     )
 
-    def profiling_func(*args, **kwargs):
-        with cProfile.Profile() as pr:
-            ret = f(*args, **kwargs)
-        st = pstats.Stats(pr)
-        st.stream = None  # make it picklable
-        st.strip_dirs()
+    if _is_iter_based(eval_type):
+
+        def profiling_func(*args, **kwargs):
+            iterator = iter(f(*args, **kwargs))
+            pr = cProfile.Profile()
+            while True:
+                try:
+                    with pr:
+                        item = next(iterator)
+                    yield item
+                except StopIteration:
+                    break
+
+            st = pstats.Stats(pr)
+            st.stream = None  # make it picklable
+            st.strip_dirs()
+
+            accumulator.add({result_id: (st, None)})
 
-        accumulator.add({result_id: (st, None)})
+    else:
+
+        def profiling_func(*args, **kwargs):
+            with cProfile.Profile() as pr:
+                ret = f(*args, **kwargs)
+            st = pstats.Stats(pr)
+            st.stream = None  # make it picklable
+            st.strip_dirs()
 
-        return ret
+            accumulator.add({result_id: (st, None)})
+
+            return ret
 
     return profiling_func
 
 
-def wrap_memory_profiler(f, result_id):
+def wrap_memory_profiler(f, eval_type, result_id):
     from pyspark.sql.profiler import ProfileResultsParam
     from pyspark.profiler import UDFLineProfilerV2
 
+    if not has_memory_profiler:
+        return f
+
     accumulator = _deserialize_accumulator(
         SpecialAccumulatorIds.SQL_UDF_PROFIER, None, ProfileResultsParam
     )
 
-    def profiling_func(*args, **kwargs):
-        profiler = UDFLineProfilerV2()
+    if _is_iter_based(eval_type):
 
-        wrapped = profiler(f)
-        ret = wrapped(*args, **kwargs)
-        codemap_dict = {
-            filename: list(line_iterator) for filename, line_iterator in profiler.code_map.items()
-        }
-        accumulator.add({result_id: (None, codemap_dict)})
-        return ret
+        def profiling_func(*args, **kwargs):
+            profiler = UDFLineProfilerV2()
+            profiler.add_function(f)
+
+            iterator = iter(f(*args, **kwargs))
+
+            while True:
+                try:
+                    with profiler:
+                        item = next(iterator)
+                    yield item
+                except StopIteration:
+                    break
+
+            codemap_dict = {
+                filename: list(line_iterator)
+                for filename, line_iterator in profiler.code_map.items()
+            }
+            accumulator.add({result_id: (None, codemap_dict)})
+
+    else:
+
+        def profiling_func(*args, **kwargs):
+            profiler = UDFLineProfilerV2()
+            profiler.add_function(f)
+
+            with profiler:
+                ret = f(*args, **kwargs)
+
+            codemap_dict = {
+                filename: list(line_iterator)
+                for filename, line_iterator in profiler.code_map.items()
+            }
+            accumulator.add({result_id: (None, codemap_dict)})
+            return ret
 
     return profiling_func
 
@@ -1254,17 +1307,12 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index, profil
     if profiler == "perf":
         result_id = read_long(infile)
 
-        if _supports_profiler(eval_type):
-            profiling_func = wrap_perf_profiler(chained_func, result_id)
-        else:
-            profiling_func = chained_func
+        profiling_func = wrap_perf_profiler(chained_func, eval_type, result_id)
 
     elif profiler == "memory":
         result_id = read_long(infile)
-        if _supports_profiler(eval_type) and has_memory_profiler:
-            profiling_func = wrap_memory_profiler(chained_func, result_id)
-        else:
-            profiling_func = chained_func
+
+        profiling_func = wrap_memory_profiler(chained_func, eval_type, result_id)
     else:
         profiling_func = chained_func
 
```
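The key technique in `wrap_perf_profiler` for iterator-based UDFs is enabling the profiler only around each `next()` call, so time spent by the downstream consumer between yields is not charged to the UDF. A standalone sketch of the same idea outside Spark (the `profile_generator` and `doubled` names are illustrative only, not part of this PR):

```python
import cProfile
import pstats
from typing import Callable, Iterator


def profile_generator(f: Callable[..., Iterator]) -> Callable[..., Iterator]:
    def wrapper(*args, **kwargs):
        iterator = iter(f(*args, **kwargs))
        pr = cProfile.Profile()
        while True:
            try:
                with pr:                 # profiler enabled only while the generator runs
                    item = next(iterator)
                yield item               # consumer code runs with the profiler disabled
            except StopIteration:
                break
        # Only reached once the generator is fully exhausted.
        pstats.Stats(pr).strip_dirs().sort_stats("cumulative").print_stats(5)

    return wrapper


@profile_generator
def doubled(values):
    for v in values:
        yield v * 2


list(doubled(range(1000)))  # prints the cProfile summary after exhaustion
```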

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala

Lines changed: 2 additions & 1 deletion

```diff
@@ -172,7 +172,8 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
       pythonRunnerConf,
       metrics,
       jobArtifactUUID,
-      sessionUUID)
+      sessionUUID,
+      conf.pythonUDFProfiler)
   }
 
   def createPythonMetrics(): Array[CustomMetric] = {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala

Lines changed: 3 additions & 2 deletions

```diff
@@ -41,7 +41,8 @@ class MapInBatchEvaluatorFactory(
     pythonRunnerConf: Map[String, String],
     val pythonMetrics: Map[String, SQLMetric],
     jobArtifactUUID: Option[String],
-    sessionUUID: Option[String])
+    sessionUUID: Option[String],
+    profiler: Option[String])
     extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
 
   override def createEvaluator(): PartitionEvaluator[InternalRow, InternalRow] =
@@ -74,7 +75,7 @@
         pythonMetrics,
         jobArtifactUUID,
         sessionUUID,
-        None) with BatchedPythonArrowInput
+        profiler) with BatchedPythonArrowInput
     val columnarBatchIter = pyRunner.compute(batchIter, context.partitionId(), context)
 
     val unsafeProj = UnsafeProjection.create(output, output)
```

sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala

Lines changed: 2 additions & 1 deletion

```diff
@@ -70,7 +70,8 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
       pythonRunnerConf,
       pythonMetrics,
       jobArtifactUUID,
-      sessionUUID)
+      sessionUUID,
+      conf.pythonUDFProfiler)
 
     val rdd = if (isBarrier) {
       val rddBarrier = child.execute().barrier()
```
