
Commit b41783f

Initial implementation of default bytearray -> bytes. WIP.
1 parent 8c422f9 commit b41783f

20 files changed: +274 −59 lines changed
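
For context, a minimal sketch of the behavior this commit introduces, pieced together from the diffs below (assumes an active SparkSession bound to `spark`; the conf key is the one used in the new tests in this commit):

    # New default: BINARY values collect as immutable bytes.
    df = spark.createDataFrame([(bytearray(b"test"),)], ["binary_col"])
    assert isinstance(df.collect()[0].binary_col, bytes)

    # The legacy bytearray behavior stays reachable via the new conf.
    spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes.enabled", "false")
    assert isinstance(df.collect()[0].binary_col, bytearray)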

python/pyspark/serializers.py

Lines changed: 38 additions & 0 deletions
@@ -545,6 +545,44 @@ def __repr__(self):
         return "CompressedSerializer(%s)" % self.serializer


+class BinaryConvertingSerializer(Serializer):
+    """
+    Converts bytearray to bytes for binary data when binary_as_bytes is enabled
+    """
+
+    def __init__(self, serializer, binary_as_bytes=False):
+        self.serializer = serializer
+        self.binary_as_bytes = binary_as_bytes
+
+    def _convert_binary(self, obj):
+        """Recursively convert bytearray to bytes in data structures"""
+        if not self.binary_as_bytes:
+            return obj
+
+        if isinstance(obj, bytearray):
+            return bytes(obj)
+        elif isinstance(obj, (list, tuple)):
+            converted = [self._convert_binary(item) for item in obj]
+            return type(obj)(converted)
+        elif isinstance(obj, dict):
+            return {key: self._convert_binary(value) for key, value in obj.items()}
+        else:
+            return obj
+
+    def dump_stream(self, iterator, stream):
+        self.serializer.dump_stream(iterator, stream)
+
+    def load_stream(self, stream):
+        for obj in self.serializer.load_stream(stream):
+            yield self._convert_binary(obj)
+
+    def __repr__(self):
+        return "BinaryConvertingSerializer(%s, binary_as_bytes=%s)" % (
+            str(self.serializer),
+            self.binary_as_bytes,
+        )
+
+
 class UTF8Deserializer(Serializer):

     """

python/pyspark/sql/avro/functions.py

Lines changed: 3 additions & 3 deletions
@@ -69,7 +69,7 @@ def from_avro(
     >>> df = spark.createDataFrame(data, ("key", "value"))
     >>> avroDf = df.select(to_avro(df.value).alias("avro"))
     >>> avroDf.collect()
-    [Row(avro=bytearray(b'\\x00\\x00\\x04\\x00\\nAlice'))]
+    [Row(avro=b'\\x00\\x00\\x04\\x00\\nAlice')]

     >>> jsonFormatSchema = '''{"type":"record","name":"topLevelRecord","fields":
     ... [{"name":"avro","type":[{"type":"record","name":"value","namespace":"topLevelRecord",
@@ -141,12 +141,12 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
     >>> data = ['SPADES']
     >>> df = spark.createDataFrame(data, "string")
     >>> df.select(to_avro(df.value).alias("suite")).collect()
-    [Row(suite=bytearray(b'\\x00\\x0cSPADES'))]
+    [Row(suite=b'\\x00\\x0cSPADES')]

     >>> jsonFormatSchema = '''["null", {"type": "enum", "name": "value",
     ... "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]'''
     >>> df.select(to_avro(df.value, jsonFormatSchema).alias("suite")).collect()
-    [Row(suite=bytearray(b'\\x02\\x00'))]
+    [Row(suite=b'\\x02\\x00')]
     """
     from py4j.java_gateway import JVMView
     from pyspark.sql.classic.column import _to_java_column
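
These doctest updates reflect the new default output type. If the legacy bytearray results are still needed, the conf introduced by this commit should restore them (per the tests later in this diff); a hypothetical continuation of the doctest above:

    >>> spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes.enabled", "false")
    >>> df.select(to_avro(df.value).alias("suite")).collect()
    [Row(suite=bytearray(b'\x00\x0cSPADES'))]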

python/pyspark/sql/conversion.py

Lines changed: 45 additions & 12 deletions
@@ -46,6 +46,21 @@
     import pyarrow as pa


+def _should_use_bytes_for_binary(binary_as_bytes: Optional[bool] = None) -> bool:
+    """Check if BINARY type should be converted to bytes instead of bytearray."""
+    if binary_as_bytes is not None:
+        return binary_as_bytes
+
+    from pyspark.sql import SparkSession
+
+    spark = SparkSession.getActiveSession()
+    if spark is not None:
+        v = spark.conf.get("spark.sql.execution.pyspark.binaryAsBytes.enabled", "true")
+        return str(v).lower() == "true"
+
+    return True
+
+
 class LocalDataToArrowConversion:
     """
     Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow.
@@ -518,13 +533,16 @@ def _create_converter(dataType: DataType) -> Callable:
     @overload
     @staticmethod
     def _create_converter(
-        dataType: DataType, *, none_on_identity: bool = True
+        dataType: DataType, *, none_on_identity: bool = True, binary_as_bytes: Optional[bool] = None
     ) -> Optional[Callable]:
         pass

     @staticmethod
     def _create_converter(
-        dataType: DataType, *, none_on_identity: bool = False
+        dataType: DataType,
+        *,
+        none_on_identity: bool = False,
+        binary_as_bytes: Optional[bool] = None,
     ) -> Optional[Callable]:
         assert dataType is not None and isinstance(dataType, DataType)

@@ -542,7 +560,9 @@ def _create_converter(
             dedup_field_names = _dedup_names(field_names)

             field_convs = [
-                ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
+                ArrowTableToRowsConversion._create_converter(
+                    f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes
+                )
                 for f in dataType.fields
             ]

@@ -564,7 +584,7 @@ def convert_struct(value: Any) -> Any:

         elif isinstance(dataType, ArrayType):
             element_conv = ArrowTableToRowsConversion._create_converter(
-                dataType.elementType, none_on_identity=True
+                dataType.elementType, none_on_identity=True, binary_as_bytes=binary_as_bytes
             )

             if element_conv is None:
@@ -589,10 +609,10 @@ def convert_array(value: Any) -> Any:

         elif isinstance(dataType, MapType):
             key_conv = ArrowTableToRowsConversion._create_converter(
-                dataType.keyType, none_on_identity=True
+                dataType.keyType, none_on_identity=True, binary_as_bytes=binary_as_bytes
             )
             value_conv = ArrowTableToRowsConversion._create_converter(
-                dataType.valueType, none_on_identity=True
+                dataType.valueType, none_on_identity=True, binary_as_bytes=binary_as_bytes
             )

             if key_conv is None:
@@ -646,7 +666,10 @@ def convert_binary(value: Any) -> Any:
                     return None
                 else:
                     assert isinstance(value, bytes)
-                    return bytearray(value)
+                    if _should_use_bytes_for_binary(binary_as_bytes):
+                        return value
+                    else:
+                        return bytearray(value)

             return convert_binary

@@ -676,7 +699,7 @@ def convert_timestample_ntz(value: Any) -> Any:
             udt: UserDefinedType = dataType

             conv = ArrowTableToRowsConversion._create_converter(
-                udt.sqlType(), none_on_identity=True
+                udt.sqlType(), none_on_identity=True, binary_as_bytes=binary_as_bytes
             )

             if conv is None:
@@ -722,20 +745,28 @@ def convert_variant(value: Any) -> Any:
     @overload
     @staticmethod
     def convert(  # type: ignore[overload-overlap]
-        table: "pa.Table", schema: StructType
+        table: "pa.Table", schema: StructType, *, binary_as_bytes: Optional[bool] = None
     ) -> List[Row]:
         pass

     @overload
     @staticmethod
     def convert(
-        table: "pa.Table", schema: StructType, *, return_as_tuples: bool = True
+        table: "pa.Table",
+        schema: StructType,
+        *,
+        return_as_tuples: bool = True,
+        binary_as_bytes: Optional[bool] = None,
     ) -> List[tuple]:
         pass

     @staticmethod  # type: ignore[misc]
     def convert(
-        table: "pa.Table", schema: StructType, *, return_as_tuples: bool = False
+        table: "pa.Table",
+        schema: StructType,
+        *,
+        return_as_tuples: bool = False,
+        binary_as_bytes: Optional[bool] = None,
     ) -> List[Union[Row, tuple]]:
         require_minimum_pyarrow_version()
         import pyarrow as pa
@@ -748,7 +779,9 @@ def convert(

         if len(fields) > 0:
             field_converters = [
-                ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
+                ArrowTableToRowsConversion._create_converter(
+                    f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes
+                )
                 for f in schema.fields
             ]
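
The new `_should_use_bytes_for_binary` helper establishes a three-level precedence: explicit argument, then active session conf, then a default of bytes. A standalone sketch of just that decision logic (the `conf_value` parameter stands in for the session conf lookup and is illustrative):

    def should_use_bytes(binary_as_bytes=None, conf_value=None):
        # 1. An explicit argument wins outright.
        if binary_as_bytes is not None:
            return binary_as_bytes
        # 2. Otherwise fall back to the active session conf, if one exists.
        if conf_value is not None:
            return str(conf_value).lower() == "true"
        # 3. No session at all: default to the new bytes behavior.
        return True

    assert should_use_bytes() is True
    assert should_use_bytes(binary_as_bytes=False) is False
    assert should_use_bytes(conf_value="false") is False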

python/pyspark/sql/functions/builtin.py

Lines changed: 6 additions & 6 deletions
@@ -16591,15 +16591,15 @@ def to_binary(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> C

     >>> import pyspark.sql.functions as sf
     >>> df = spark.createDataFrame([("abc",)], ["e"])
-    >>> df.select(sf.try_to_binary(df.e, sf.lit("utf-8")).alias('r')).collect()
-    [Row(r=bytearray(b'abc'))]
+    >>> df.select(sf.to_binary(df.e, sf.lit("utf-8")).alias('r')).collect()
+    [Row(r=b'abc')]

     Example 2: Convert string to a timestamp without encoding specified

     >>> import pyspark.sql.functions as sf
     >>> df = spark.createDataFrame([("414243",)], ["e"])
-    >>> df.select(sf.try_to_binary(df.e).alias('r')).collect()
-    [Row(r=bytearray(b'ABC'))]
+    >>> df.select(sf.to_binary(df.e).alias('r')).collect()
+    [Row(r=b'ABC')]
     """
     if format is not None:
         return _invoke_function_over_columns("to_binary", col, format)
@@ -17615,14 +17615,14 @@ def try_to_binary(col: "ColumnOrName", format: Optional["ColumnOrName"] = None)

     >>> import pyspark.sql.functions as sf
     >>> df = spark.createDataFrame([("abc",)], ["e"])
     >>> df.select(sf.try_to_binary(df.e, sf.lit("utf-8")).alias('r')).collect()
-    [Row(r=bytearray(b'abc'))]
+    [Row(r=b'abc')]

     Example 2: Convert string to a timestamp without encoding specified

     >>> import pyspark.sql.functions as sf
     >>> df = spark.createDataFrame([("414243",)], ["e"])
     >>> df.select(sf.try_to_binary(df.e).alias('r')).collect()
-    [Row(r=bytearray(b'ABC'))]
+    [Row(r=b'ABC')]

     Example 3: Converion failure results in NULL when ANSI mode is on
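
Besides switching the expected outputs to bytes, the first hunk also corrects the `to_binary` doctests, which previously called `try_to_binary`. A quick sanity check of the corrected example (assumes an active session bound to `spark`):

    import pyspark.sql.functions as sf

    df = spark.createDataFrame([("414243",)], ["e"])
    [row] = df.select(sf.to_binary(df.e).alias("r")).collect()
    assert row.r == b"ABC"  # "414243" is the hex encoding of "ABC"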

python/pyspark/sql/pandas/serializers.py

Lines changed: 20 additions & 4 deletions
@@ -350,11 +350,12 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
     This has performance penalties.
     """

-    def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
+    def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled, binary_as_bytes=None):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
         self._safecheck = safecheck
         self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
+        self._binary_as_bytes = binary_as_bytes
         self._converter_cache = {}

     @staticmethod
@@ -583,9 +584,10 @@ def __init__(
         arrow_cast=False,
         input_types=None,
         int_to_decimal_coercion_enabled=False,
+        binary_as_bytes=None,
     ):
         super(ArrowStreamPandasUDFSerializer, self).__init__(
-            timezone, safecheck, int_to_decimal_coercion_enabled
+            timezone, safecheck, int_to_decimal_coercion_enabled, binary_as_bytes
         )
         self._assign_cols_by_name = assign_cols_by_name
         self._df_for_struct = df_for_struct
@@ -782,12 +784,14 @@ def __init__(
         safecheck,
         assign_cols_by_name,
         arrow_cast,
+        binary_as_bytes=None,
     ):
         super(ArrowStreamArrowUDFSerializer, self).__init__()
         self._timezone = timezone
         self._safecheck = safecheck
         self._assign_cols_by_name = assign_cols_by_name
         self._arrow_cast = arrow_cast
+        self._binary_as_bytes = binary_as_bytes

     def _create_array(self, arr, arrow_type, arrow_cast):
         import pyarrow as pa
@@ -862,12 +866,14 @@ def __init__(
         safecheck,
         input_types,
         int_to_decimal_coercion_enabled=False,
+        binary_as_bytes=None,
     ):
         super().__init__(
             timezone=timezone,
             safecheck=safecheck,
             assign_cols_by_name=False,
             arrow_cast=True,
+            binary_as_bytes=binary_as_bytes,
         )
         self._input_types = input_types
         self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
@@ -887,7 +893,9 @@ def load_stream(self, stream):
         List of columns containing list of Python values.
         """
         converters = [
-            ArrowTableToRowsConversion._create_converter(dt, none_on_identity=True)
+            ArrowTableToRowsConversion._create_converter(
+                dt, none_on_identity=True, binary_as_bytes=self._binary_as_bytes
+            )
             for dt in self._input_types
         ]
@@ -949,7 +957,14 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
     Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
     """

-    def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_enabled):
+    def __init__(
+        self,
+        timezone,
+        safecheck,
+        input_types,
+        int_to_decimal_coercion_enabled,
+        binary_as_bytes=None,
+    ):
         super(ArrowStreamPandasUDTFSerializer, self).__init__(
             timezone=timezone,
             safecheck=safecheck,
@@ -972,6 +987,7 @@ def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_ena
             input_types=input_types,
             # Enable additional coercions for UDTF serialization
             int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
+            binary_as_bytes=binary_as_bytes,
         )
         self._converter_map = dict()
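
Every change in this file follows one pattern: each serializer constructor gains an optional `binary_as_bytes=None` keyword and forwards it up the inheritance chain, so the flag is stored once and later consulted on the load path. A minimal standalone sketch of the pattern (simplified names, not the actual classes):

    class BaseSerializer:
        def __init__(self, timezone, safecheck, binary_as_bytes=None):
            # None means: defer to the session conf at conversion time.
            self._binary_as_bytes = binary_as_bytes

    class UDFSerializer(BaseSerializer):
        def __init__(self, timezone, safecheck, binary_as_bytes=None):
            # Forward unchanged so the base class keeps the single copy.
            super().__init__(timezone, safecheck, binary_as_bytes=binary_as_bytes)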

python/pyspark/sql/tests/arrow/test_arrow.py

Lines changed: 41 additions & 0 deletions
@@ -1781,6 +1781,47 @@ def test_createDataFrame_arrow_fixed_size_binary(self):
         df = self.spark.createDataFrame(t)
         self.assertIsInstance(df.schema["fsb"].dataType, BinaryType)

+    def test_binary_type_default_bytes_behavior(self):
+        """Test that binary values are returned as bytes by default"""
+        df = self.spark.createDataFrame([(bytearray(b"test"),)], ["binary_col"])
+        collected = df.collect()
+        self.assertIsInstance(collected[0].binary_col, bytes)
+        self.assertEqual(collected[0].binary_col, b"test")
+
+    def test_binary_type_config_enabled_bytes(self):
+        """Test that binary values are returned as bytes when config is enabled"""
+        with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes.enabled": "true"}):
+            df = self.spark.createDataFrame([(bytearray(b"test"),)], ["binary_col"])
+            collected = df.collect()
+            self.assertIsInstance(collected[0].binary_col, bytes)
+            self.assertEqual(collected[0].binary_col, b"test")
+
+    def test_binary_type_config_disabled_bytearray(self):
+        """Test that binary values are returned as bytearray when config is disabled"""
+        with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes.enabled": "false"}):
+            df = self.spark.createDataFrame([(bytearray(b"test"),)], ["binary_col"])
+            collected = df.collect()
+            self.assertIsInstance(collected[0].binary_col, bytearray)
+            self.assertEqual(collected[0].binary_col, bytearray(b"test"))
+
+    def test_binary_type_to_local_iterator_bytes_mode(self):
+        """Test binary type with toLocalIterator when bytes mode is enabled"""
+        with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes.enabled": "true"}):
+            df = self.spark.createDataFrame([(b"test1",), (b"test2",)], ["binary_col"])
+            local_iter = df.toLocalIterator()
+            rows = list(local_iter)
+            for row in rows:
+                self.assertIsInstance(row.binary_col, bytes)
+
+    def test_binary_type_to_local_iterator_bytearray_mode(self):
+        """Test binary type with toLocalIterator when bytearray mode is enabled"""
+        with self.sql_conf({"spark.sql.execution.pyspark.binaryAsBytes.enabled": "false"}):
+            df = self.spark.createDataFrame([(b"test1",), (b"test2",)], ["binary_col"])
+            local_iter = df.toLocalIterator()
+            rows = list(local_iter)
+            for row in rows:
+                self.assertIsInstance(row.binary_col, bytearray)
+
     def test_createDataFrame_arrow_fixed_size_list(self):
         a = pa.array([[-1, 3]] * 5, type=pa.list_(pa.int32(), 2))
         t = pa.table([a], ["fsl"])
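
The same checks can be reproduced outside the test harness; a sketch assuming an active session bound to `spark` (the tests above use the `sql_conf` helper, which scopes the conf change to the block):

    spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes.enabled", "false")
    df = spark.createDataFrame([(b"test1",), (b"test2",)], ["binary_col"])
    for row in df.toLocalIterator():
        assert isinstance(row.binary_col, bytearray)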
