Add test cases for AsyncSparkCursor
laughingman7743 committed Jan 8, 2024
1 parent 4b0cda5 commit e976d98
Showing 4 changed files with 175 additions and 18 deletions.
22 changes: 14 additions & 8 deletions pyathena/spark/async_spark_cursor.py
@@ -26,17 +26,23 @@ def close(self, wait: bool = False) -> None:
     def calculation_execution(self, query_id: str) -> "Future[AthenaCalculationExecution]":
         return self._executor.submit(self._get_calculation_execution, query_id)
 
-    def get_std_out(self, query_id: str) -> Optional[str]:
-        calculation_execution = self._get_calculation_execution(query_id)
-        if not calculation_execution or not calculation_execution.std_out_s3_uri:
+    def get_std_out(
+        self, calculation_execution: AthenaCalculationExecution
+    ) -> "Optional[Future[str]]":
+        if not calculation_execution.std_out_s3_uri:
             return None
-        return self._read_s3_file_as_text(calculation_execution.std_out_s3_uri)
+        return self._executor.submit(
+            self._read_s3_file_as_text, calculation_execution.std_out_s3_uri
+        )
 
-    def get_std_error(self, query_id: str) -> Optional[str]:
-        calculation_execution = self._get_calculation_execution(query_id)
-        if not calculation_execution or not calculation_execution.std_error_s3_uri:
+    def get_std_error(
+        self, calculation_execution: AthenaCalculationExecution
+    ) -> "Optional[Future[str]]":
+        if not calculation_execution.std_error_s3_uri:
             return None
-        return self._read_s3_file_as_text(calculation_execution.std_error_s3_uri)
+        return self._executor.submit(
+            self._read_s3_file_as_text, calculation_execution.std_error_s3_uri
+        )
 
     def poll(self, query_id: str) -> "Future[AthenaCalculationExecution]":
         return cast(
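For reference, a minimal usage sketch of the revised API. The work group name and region are illustrative placeholders; any Spark-enabled Athena work group works:

    from pyathena import connect
    from pyathena.spark.async_spark_cursor import AsyncSparkCursor

    # Hypothetical connection values; substitute your own Spark-enabled work group.
    cursor = connect(
        work_group="example-spark-work-group", region_name="us-west-2"
    ).cursor(AsyncSparkCursor)

    # execute() returns the calculation id plus a Future immediately.
    query_id, future = cursor.execute('spark.sql("SELECT 1").show()')
    calculation_execution = future.result()

    # get_std_out()/get_std_error() now take the finished
    # AthenaCalculationExecution and return Optional[Future[str]] rather
    # than fetching the S3 text synchronously by query id.
    std_out_future = cursor.get_std_out(calculation_execution)
    if std_out_future:
        print(std_out_future.result())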
6 changes: 6 additions & 0 deletions tests/pyathena/conftest.py
@@ -174,13 +174,19 @@ def async_arrow_cursor(request):
 def spark_cursor(request):
     from pyathena.spark.spark_cursor import SparkCursor
 
+    if not hasattr(request, "param"):
+        setattr(request, "param", {})
+    request.param.update({"work_group": ENV.spark_work_group})
     yield from _cursor(SparkCursor, request)
 
 
 @pytest.fixture
 def async_spark_cursor(request):
     from pyathena.spark.async_spark_cursor import AsyncSparkCursor
 
+    if not hasattr(request, "param"):
+        setattr(request, "param", {})
+    request.param.update({"work_group": ENV.spark_work_group})
     yield from _cursor(AsyncSparkCursor, request)


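For context, pytest only sets request.param when a test parametrizes a fixture indirectly, so the hasattr guard keeps these fixtures usable without any parametrization. A standalone sketch of the pattern (hypothetical fixture name and values, not part of this commit):

    import pytest

    @pytest.fixture
    def cursor_kwargs(request):
        # request.param exists only under indirect parametrization.
        if not hasattr(request, "param"):
            setattr(request, "param", {})
        # Defaults merged last always win, mirroring the conftest change above.
        request.param.update({"work_group": "example-spark-work-group"})
        return request.param

    def test_defaults(cursor_kwargs):
        # No parametrization: param was created empty, then the default applied.
        assert cursor_kwargs == {"work_group": "example-spark-work-group"}

    @pytest.mark.parametrize("cursor_kwargs", [{"schema_name": "example"}], indirect=True)
    def test_override(cursor_kwargs):
        # Extra keys pass through; work_group is still forced by the fixture.
        assert cursor_kwargs["schema_name"] == "example"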
126 changes: 126 additions & 0 deletions tests/pyathena/spark/test_async_spark_cursor.py
@@ -0,0 +1,126 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import textwrap
import time
from random import randint

from pyathena.model import AthenaCalculationExecution
from tests import ENV


class TestAsyncSparkCursor:
    def test_spark_dataframe(self, async_spark_cursor):
        query_id, future = async_spark_cursor.execute(
            textwrap.dedent(
                f"""
                df = spark.read.format("csv") \\
                    .option("header", "true") \\
                    .option("inferSchema", "true") \\
                    .load("{ENV.s3_staging_dir}{ENV.schema}/spark_group_by/spark_group_by.csv")
                """
            ),
            description="test description",
        )
        calculation_execution = future.result()
        assert calculation_execution.session_id
        assert query_id == calculation_execution.calculation_id
        assert calculation_execution.description == "test description"
        assert calculation_execution.working_directory
        assert calculation_execution.state == AthenaCalculationExecution.STATE_COMPLETED
        assert calculation_execution.state_change_reason is None
        assert calculation_execution.submission_date_time
        assert calculation_execution.completion_date_time
        assert calculation_execution.dpu_execution_in_millis
        assert calculation_execution.progress
        assert calculation_execution.std_out_s3_uri
        assert calculation_execution.std_error_s3_uri
        assert calculation_execution.result_s3_uri
        assert calculation_execution.result_type

        query_id, future = async_spark_cursor.execute(
            textwrap.dedent(
                """
                from pyspark.sql.functions import sum
                df_count = df.groupBy("name").agg(sum("count").alias("sum"))
                df_count.show()
                """
            )
        )
        calculation_execution = future.result()
        std_out = async_spark_cursor.get_std_out(calculation_execution).result()
        assert (
            std_out
            == textwrap.dedent(
                """
                +----+---+
                |name|sum|
                +----+---+
                | bar|  5|
                | foo|  5|
                +----+---+
                """
            ).strip()
        )

    def test_spark_sql(self, async_spark_cursor):
        query_id, future = async_spark_cursor.execute(
            textwrap.dedent(
                f"""
                spark.sql("SELECT * FROM {ENV.schema}.one_row").show()
                """
            )
        )
        calculation_execution = future.result()
        std_out = async_spark_cursor.get_std_out(calculation_execution).result()
        assert (
            std_out
            == textwrap.dedent(
                """
                +--------------+
                |number_of_rows|
                +--------------+
                |             1|
                +--------------+
                """
            ).strip()
        )

    def test_failed(self, async_spark_cursor):
        query_id, future = async_spark_cursor.execute(
            textwrap.dedent(
                """
                foobar
                """
            )
        )
        calculation_execution = future.result()
        assert calculation_execution.state == AthenaCalculationExecution.STATE_FAILED
        std_error = async_spark_cursor.get_std_error(calculation_execution).result()
        assert (
            std_error
            == textwrap.dedent(
                """
                File "<stdin>", line 2, in <module>
                NameError: name 'foobar' is not defined
                """
            ).strip()
        )

    def test_cancel(self, async_spark_cursor):
        query_id, future = async_spark_cursor.execute(
            textwrap.dedent(
                f"""
                spark.sql(
                    '''
                    SELECT a.a * rand(), b.a * rand()
                    FROM {ENV.schema}.many_rows a
                    CROSS JOIN {ENV.schema}.many_rows b
                    '''
                )
                """
            )
        )
        time.sleep(randint(5, 10))
        async_spark_cursor.cancel(query_id)
        calculation_execution = future.result()
        assert calculation_execution.state == AthenaCalculationExecution.STATE_CANCELED
39 changes: 29 additions & 10 deletions tests/pyathena/spark/test_spark_cursor.py
@@ -1,18 +1,18 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 import textwrap
+import time
+from concurrent.futures import ThreadPoolExecutor
+from random import randint
 
 import pytest
 
-from pyathena import OperationalError
+from pyathena import DatabaseError, OperationalError
 from pyathena.model import AthenaCalculationExecution
 from tests import ENV
 
 
 class TestSparkCursor:
-    @pytest.mark.parametrize(
-        "spark_cursor", [{"work_group": ENV.spark_work_group}], indirect=["spark_cursor"]
-    )
     def test_spark_dataframe(self, spark_cursor):
         spark_cursor.execute(
             textwrap.dedent(
@@ -76,9 +76,6 @@ def test_spark_dataframe(self, spark_cursor):
         )
 
     @pytest.mark.depends(on="test_spark_dataframe")
-    @pytest.mark.parametrize(
-        "spark_cursor", [{"work_group": ENV.spark_work_group}], indirect=["spark_cursor"]
-    )
     def test_spark_sql(self, spark_cursor):
         spark_cursor.execute(
             textwrap.dedent(
@@ -108,9 +105,6 @@ def test_spark_sql(self, spark_cursor):
             )
         )
 
-    @pytest.mark.parametrize(
-        "spark_cursor", [{"work_group": ENV.spark_work_group}], indirect=["spark_cursor"]
-    )
     def test_failed(self, spark_cursor):
         with pytest.raises(OperationalError):
             spark_cursor.execute(
@@ -130,3 +124,28 @@ def test_failed(self, spark_cursor):
                 """
             ).strip()
         )
+
+    def test_cancel(self, spark_cursor):
+        def cancel(c):
+            time.sleep(randint(5, 10))
+            c.cancel()
+
+        with ThreadPoolExecutor(max_workers=1) as executor:
+            executor.submit(cancel, spark_cursor)
+
+            pytest.raises(
+                DatabaseError,
+                lambda: spark_cursor.execute(
+                    textwrap.dedent(
+                        f"""
+                        spark.sql(
+                            '''
+                            SELECT a.a * rand(), b.a * rand()
+                            FROM {ENV.schema}.many_rows a
+                            CROSS JOIN {ENV.schema}.many_rows b
+                            '''
+                        )
+                        """
+                    )
+                ),
+            )
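Note the asymmetry between the two cancellation tests: SparkCursor.execute() blocks, so the cancel must be issued from a worker thread and surfaces as a DatabaseError in the calling test, while AsyncSparkCursor.execute() returns a Future immediately, so cancel(query_id) can be called inline and the resolved AthenaCalculationExecution simply reports STATE_CANCELED.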
