Skip to content

Commit

Permalink
respond to code review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dtenedor committed Nov 13, 2023
1 parent dde7fa9 commit 768096d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
29 changes: 21 additions & 8 deletions python/pyspark/sql/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2469,17 +2469,11 @@ def terminate(self):
)

def test_udtf_with_skip_rest_of_input_table_exception(self):
@udtf
@udtf(returnType="total: int")
class TestUDTF:
def __init__(self):
self._total = 0

@staticmethod
def analyze(_):
return AnalyzeResult(
schema=StructType().add("total", IntegerType()), withSinglePartition=True
)

def eval(self, _: Row):
self._total += 1
if self._total >= 4:
Expand All @@ -2490,18 +2484,37 @@ def terminate(self):

self.spark.udtf.register("test_udtf", TestUDTF)

# Run a test case including WITH SINGLE PARTITION on the UDTF call. The
# SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
assertDataFrameEqual(
self.spark.sql(
"""
WITH t AS (
SELECT id FROM range(1, 21)
)
SELECT total
FROM test_udtf(TABLE(t))
FROM test_udtf(TABLE(t) WITH SINGLE PARTITION)
"""
),
[Row(total=4)],
)
# Run a test case including PARTITION BY on the UDTF call. The
# SkipRestOfInputTableException stops scanning rows for each of the two partitions
# separately.
assertDataFrameEqual(
self.spark.sql(
"""
WITH t AS (
SELECT id FROM range(1, 21)
)
SELECT id / 10 AS id_divided_by_ten, total
FROM test_udtf(TABLE(t) PARTITION BY id / 10)
ORDER BY ALL
"""
),
[Row(id_divided_by_ten=0, total=4),
Row(id_divided_by_ten=1, total=4)],
)


class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
Expand Down
8 changes: 5 additions & 3 deletions python/pyspark/sql/udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,12 @@ class AnalyzeResult:
orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)


# This represents an exception that the 'eval' method may raise to indicate that it is done
# consuming rows from the current partition of the input table. Then the UDTF's 'terminate' method
# runs (if any).
class SkipRestOfInputTableException(Exception):
    """
    Signal that a UDTF is done consuming rows from the current input partition.

    The 'eval' method may raise this exception to indicate that it is done
    consuming rows from the current partition of the input table. Then the
    UDTF's 'terminate' method runs (if any).
    """


Expand Down

0 comments on commit 768096d

Please sign in to comment.