2 changes: 1 addition & 1 deletion docs/sql-ref-datatypes.md
@@ -131,7 +131,7 @@ from pyspark.sql.types import *
|**StringType**|str|StringType()|
|**CharType(length)**|str|CharType(length)|
|**VarcharType(length)**|str|VarcharType(length)|
|**BinaryType**|bytearray|BinaryType()|
|**BinaryType**|bytes|BinaryType()|
|**BooleanType**|bool|BooleanType()|
|**TimestampType**|datetime.datetime|TimestampType()|
|**TimestampNTZType**|datetime.datetime|TimestampNTZType()|
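These documentation rows reflect the behavioral change in this PR: BinaryType values collected to the driver arrive as Python bytes rather than bytearray. A minimal sketch of what that looks like, assuming a build that includes this change with spark.sql.execution.pyspark.binaryAsBytes.enabled left at its default of true:

```python
# Minimal sketch, assuming a build with this change and the
# spark.sql.execution.pyspark.binaryAsBytes.enabled conf at its default (true).
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, BinaryType

spark = SparkSession.builder.getOrCreate()

schema = StructType([StructField("payload", BinaryType())])
df = spark.createDataFrame([(b"\x00\x01\x02",)], schema)

value = df.collect()[0].payload
print(type(value))  # <class 'bytes'> with this change; previously <class 'bytearray'>
```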
2 changes: 1 addition & 1 deletion python/docs/source/tutorial/sql/type_conversions.rst
@@ -105,7 +105,7 @@ All Conversions
- string
- StringType()
* - **BinaryType**
- bytearray
- bytes
- BinaryType()
* - **BooleanType**
- bool
38 changes: 38 additions & 0 deletions python/pyspark/serializers.py
@@ -545,6 +545,44 @@ def __repr__(self):
return "CompressedSerializer(%s)" % self.serializer


class BinaryConvertingSerializer(Serializer):
"""
Converts bytearray to bytes for binary data when binary_as_bytes is enabled
"""

def __init__(self, serializer, binary_as_bytes=False):
self.serializer = serializer
self.binary_as_bytes = binary_as_bytes

def _convert_binary(self, obj):
"""Recursively convert bytearray to bytes in data structures"""
if not self.binary_as_bytes:
return obj

if isinstance(obj, bytearray):
return bytes(obj)
elif isinstance(obj, (list, tuple)):
converted = [self._convert_binary(item) for item in obj]
return type(obj)(converted)
elif isinstance(obj, dict):
return {key: self._convert_binary(value) for key, value in obj.items()}
else:
return obj

def dump_stream(self, iterator, stream):
self.serializer.dump_stream(iterator, stream)

def load_stream(self, stream):
for obj in self.serializer.load_stream(stream):
yield self._convert_binary(obj)

def __repr__(self):
return "BinaryConvertingSerializer(%s, binary_as_bytes=%s)" % (
str(self.serializer),
self.binary_as_bytes,
)


class UTF8Deserializer(Serializer):

"""
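The new BinaryConvertingSerializer wraps an inner serializer and, while loading a stream, recursively rewrites every bytearray it finds inside lists, tuples, and dicts into bytes, leaving everything else untouched. A standalone sketch of that recursion (the helper name and sample record below are illustrative, not part of the PR):

```python
# Standalone sketch of the conversion BinaryConvertingSerializer applies in
# load_stream; the function name and sample data here are illustrative only.
def convert_binary(obj, binary_as_bytes=True):
    if not binary_as_bytes:
        return obj  # feature disabled: pass values through unchanged
    if isinstance(obj, bytearray):
        return bytes(obj)  # the actual bytearray -> bytes conversion
    if isinstance(obj, (list, tuple)):
        # rebuild the container with the same type, converting each element
        return type(obj)(convert_binary(item, binary_as_bytes) for item in obj)
    if isinstance(obj, dict):
        return {k: convert_binary(v, binary_as_bytes) for k, v in obj.items()}
    return obj  # non-binary, non-container values are left as-is


record = {"key": bytearray(b"\x01"), "values": [bytearray(b"\x02"), "text"]}
print(convert_binary(record))
# {'key': b'\x01', 'values': [b'\x02', 'text']}
```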
6 changes: 3 additions & 3 deletions python/pyspark/sql/avro/functions.py
@@ -69,7 +69,7 @@ def from_avro(
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> avroDf = df.select(to_avro(df.value).alias("avro"))
>>> avroDf.collect()
[Row(avro=bytearray(b'\\x00\\x00\\x04\\x00\\nAlice'))]
[Row(avro=b'\\x00\\x00\\x04\\x00\\nAlice')]

>>> jsonFormatSchema = '''{"type":"record","name":"topLevelRecord","fields":
... [{"name":"avro","type":[{"type":"record","name":"value","namespace":"topLevelRecord",
@@ -141,12 +141,12 @@ def to_avro(data: "ColumnOrName", jsonFormatSchema: str = "") -> Column:
>>> data = ['SPADES']
>>> df = spark.createDataFrame(data, "string")
>>> df.select(to_avro(df.value).alias("suite")).collect()
[Row(suite=bytearray(b'\\x00\\x0cSPADES'))]
[Row(suite=b'\\x00\\x0cSPADES')]

>>> jsonFormatSchema = '''["null", {"type": "enum", "name": "value",
... "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]'''
>>> df.select(to_avro(df.value, jsonFormatSchema).alias("suite")).collect()
[Row(suite=bytearray(b'\\x02\\x00'))]
[Row(suite=b'\\x02\\x00')]
"""
from py4j.java_gateway import JVMView
from pyspark.sql.classic.column import _to_java_column
57 changes: 45 additions & 12 deletions python/pyspark/sql/conversion.py
@@ -46,6 +46,21 @@
import pyarrow as pa


def _should_use_bytes_for_binary(binary_as_bytes: Optional[bool] = None) -> bool:
"""Check if BINARY type should be converted to bytes instead of bytearray."""
if binary_as_bytes is not None:
return binary_as_bytes

from pyspark.sql import SparkSession

spark = SparkSession.getActiveSession()
if spark is not None:
v = spark.conf.get("spark.sql.execution.pyspark.binaryAsBytes.enabled", "true")
return str(v).lower() == "true"

return True


class LocalDataToArrowConversion:
"""
Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow.
@@ -518,13 +533,16 @@ def _create_converter(dataType: DataType) -> Callable:
@overload
@staticmethod
def _create_converter(
dataType: DataType, *, none_on_identity: bool = True
dataType: DataType, *, none_on_identity: bool = True, binary_as_bytes: Optional[bool] = None
) -> Optional[Callable]:
pass

@staticmethod
def _create_converter(
dataType: DataType, *, none_on_identity: bool = False
dataType: DataType,
*,
none_on_identity: bool = False,
binary_as_bytes: Optional[bool] = None,
) -> Optional[Callable]:
assert dataType is not None and isinstance(dataType, DataType)

@@ -542,7 +560,9 @@ def _create_converter(
dedup_field_names = _dedup_names(field_names)

field_convs = [
ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
for f in dataType.fields
]

@@ -564,7 +584,7 @@ def convert_struct(value: Any) -> Any:

elif isinstance(dataType, ArrayType):
element_conv = ArrowTableToRowsConversion._create_converter(
dataType.elementType, none_on_identity=True
dataType.elementType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if element_conv is None:
@@ -589,10 +609,10 @@ def convert_array(value: Any) -> Any:

elif isinstance(dataType, MapType):
key_conv = ArrowTableToRowsConversion._create_converter(
dataType.keyType, none_on_identity=True
dataType.keyType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
value_conv = ArrowTableToRowsConversion._create_converter(
dataType.valueType, none_on_identity=True
dataType.valueType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if key_conv is None:
@@ -646,7 +666,10 @@ def convert_binary(value: Any) -> Any:
return None
else:
assert isinstance(value, bytes)
return bytearray(value)
if _should_use_bytes_for_binary(binary_as_bytes):
return value
else:
return bytearray(value)

return convert_binary

@@ -676,7 +699,7 @@ def convert_timestample_ntz(value: Any) -> Any:
udt: UserDefinedType = dataType

conv = ArrowTableToRowsConversion._create_converter(
udt.sqlType(), none_on_identity=True
udt.sqlType(), none_on_identity=True, binary_as_bytes=binary_as_bytes
)

if conv is None:
@@ -722,20 +745,28 @@ def convert_variant(value: Any) -> Any:
@overload
@staticmethod
def convert( # type: ignore[overload-overlap]
table: "pa.Table", schema: StructType
table: "pa.Table", schema: StructType, *, binary_as_bytes: Optional[bool] = None
) -> List[Row]:
pass

@overload
@staticmethod
def convert(
table: "pa.Table", schema: StructType, *, return_as_tuples: bool = True
table: "pa.Table",
schema: StructType,
*,
return_as_tuples: bool = True,
binary_as_bytes: Optional[bool] = None,
) -> List[tuple]:
pass

@staticmethod # type: ignore[misc]
def convert(
table: "pa.Table", schema: StructType, *, return_as_tuples: bool = False
table: "pa.Table",
schema: StructType,
*,
return_as_tuples: bool = False,
binary_as_bytes: Optional[bool] = None,
) -> List[Union[Row, tuple]]:
require_minimum_pyarrow_version()
import pyarrow as pa
@@ -748,7 +779,9 @@ def convert(

if len(fields) > 0:
field_converters = [
ArrowTableToRowsConversion._create_converter(f.dataType, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
f.dataType, none_on_identity=True, binary_as_bytes=binary_as_bytes
)
for f in schema.fields
]

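In conversion.py, the new _should_use_bytes_for_binary helper resolves the behavior in three steps: an explicit binary_as_bytes argument wins; otherwise the active SparkSession's spark.sql.execution.pyspark.binaryAsBytes.enabled conf is read; with no active session it falls back to true. A hedged usage sketch of toggling that conf to recover the legacy bytearray behavior (assuming the conf can be changed at runtime on a build with this change):

```python
# Sketch of toggling the new conf; assumes a build with this change and that
# the conf can be changed at runtime on the active session.
from pyspark.sql import SparkSession
from pyspark.sql.functions import to_binary, lit, col

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("abc",)], ["e"]).select(
    to_binary(col("e"), lit("utf-8")).alias("r")
)

spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes.enabled", "true")
print(type(df.collect()[0].r))  # expected: <class 'bytes'> (new default)

spark.conf.set("spark.sql.execution.pyspark.binaryAsBytes.enabled", "false")
print(type(df.collect()[0].r))  # expected: <class 'bytearray'> (legacy behavior)
```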
12 changes: 6 additions & 6 deletions python/pyspark/sql/functions/builtin.py
@@ -16591,15 +16591,15 @@ def to_binary(col: "ColumnOrName", format: Optional["ColumnOrName"] = None) -> C

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([("abc",)], ["e"])
>>> df.select(sf.try_to_binary(df.e, sf.lit("utf-8")).alias('r')).collect()
[Row(r=bytearray(b'abc'))]
>>> df.select(sf.to_binary(df.e, sf.lit("utf-8")).alias('r')).collect()
[Row(r=b'abc')]

Example 2: Convert string to a timestamp without encoding specified

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([("414243",)], ["e"])
>>> df.select(sf.try_to_binary(df.e).alias('r')).collect()
[Row(r=bytearray(b'ABC'))]
>>> df.select(sf.to_binary(df.e).alias('r')).collect()
[Row(r=b'ABC')]
"""
if format is not None:
return _invoke_function_over_columns("to_binary", col, format)
@@ -17615,14 +17615,14 @@ def try_to_binary(col: "ColumnOrName", format: Optional["ColumnOrName"] = None)
>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([("abc",)], ["e"])
>>> df.select(sf.try_to_binary(df.e, sf.lit("utf-8")).alias('r')).collect()
[Row(r=bytearray(b'abc'))]
[Row(r=b'abc')]

Example 2: Convert string to a timestamp without encoding specified

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([("414243",)], ["e"])
>>> df.select(sf.try_to_binary(df.e).alias('r')).collect()
[Row(r=bytearray(b'ABC'))]
[Row(r=b'ABC')]

Example 3: Conversion failure results in NULL when ANSI mode is on

24 changes: 20 additions & 4 deletions python/pyspark/sql/pandas/serializers.py
@@ -350,11 +350,12 @@ class ArrowStreamPandasSerializer(ArrowStreamSerializer):
This has performance penalties.
"""

def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled):
def __init__(self, timezone, safecheck, int_to_decimal_coercion_enabled, binary_as_bytes=None):
super(ArrowStreamPandasSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
self._binary_as_bytes = binary_as_bytes
self._converter_cache = {}

@staticmethod
@@ -583,9 +584,10 @@ def __init__(
arrow_cast=False,
input_types=None,
int_to_decimal_coercion_enabled=False,
binary_as_bytes=None,
):
super(ArrowStreamPandasUDFSerializer, self).__init__(
timezone, safecheck, int_to_decimal_coercion_enabled
timezone, safecheck, int_to_decimal_coercion_enabled, binary_as_bytes
)
self._assign_cols_by_name = assign_cols_by_name
self._df_for_struct = df_for_struct
@@ -782,12 +784,14 @@ def __init__(
safecheck,
assign_cols_by_name,
arrow_cast,
binary_as_bytes=None,
):
super(ArrowStreamArrowUDFSerializer, self).__init__()
self._timezone = timezone
self._safecheck = safecheck
self._assign_cols_by_name = assign_cols_by_name
self._arrow_cast = arrow_cast
self._binary_as_bytes = binary_as_bytes

def _create_array(self, arr, arrow_type, arrow_cast):
import pyarrow as pa
@@ -862,12 +866,14 @@ def __init__(
safecheck,
input_types,
int_to_decimal_coercion_enabled=False,
binary_as_bytes=None,
):
super().__init__(
timezone=timezone,
safecheck=safecheck,
assign_cols_by_name=False,
arrow_cast=True,
binary_as_bytes=binary_as_bytes,
)
self._input_types = input_types
self._int_to_decimal_coercion_enabled = int_to_decimal_coercion_enabled
@@ -887,7 +893,9 @@ def load_stream(self, stream):
List of columns containing list of Python values.
"""
converters = [
ArrowTableToRowsConversion._create_converter(dt, none_on_identity=True)
ArrowTableToRowsConversion._create_converter(
dt, none_on_identity=True, binary_as_bytes=self._binary_as_bytes
)
for dt in self._input_types
]

@@ -949,7 +957,14 @@ class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
"""

def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_enabled):
def __init__(
self,
timezone,
safecheck,
input_types,
int_to_decimal_coercion_enabled,
binary_as_bytes=None,
):
super(ArrowStreamPandasUDTFSerializer, self).__init__(
timezone=timezone,
safecheck=safecheck,
@@ -972,6 +987,7 @@ def __init__(self, timezone, safecheck, input_types, int_to_decimal_coercion_ena
input_types=input_types,
# Enable additional coercions for UDTF serialization
int_to_decimal_coercion_enabled=int_to_decimal_coercion_enabled,
binary_as_bytes=binary_as_bytes,
)
self._converter_map = dict()

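The pandas/Arrow serializer constructors now accept a binary_as_bytes flag and thread it into ArrowTableToRowsConversion._create_converter, so workers deserializing UDF and UDTF inputs honor the same setting. A hedged sketch of observing the input type from a Python UDF; whether a given UDF actually runs through these Arrow serializers depends on the Arrow-execution confs, so treat the printed type as expected behavior rather than a guarantee:

```python
# Sketch of observing the Python type a UDF receives for a BinaryType input;
# assumes a build with this change and Arrow-based UDF execution enabled.
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf, to_binary, lit, col
from pyspark.sql.types import StringType

spark = SparkSession.builder.getOrCreate()

@udf(returnType=StringType())
def py_type_name(b):
    # With binaryAsBytes enabled, the worker is expected to hand the UDF bytes.
    return type(b).__name__

df = spark.createDataFrame([("abc",)], ["e"]).select(
    to_binary(col("e"), lit("utf-8")).alias("r")
)
df.select(py_type_name(col("r")).alias("input_type")).show()
# expected: "bytes" when the conf is enabled, "bytearray" otherwise
```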