46
46
import pyarrow as pa
47
47
48
48
49
+ def _should_use_bytes_for_binary (binary_as_bytes : Optional [bool ] = None ) -> bool :
50
+ """Check if BINARY type should be converted to bytes instead of bytearray."""
51
+ if binary_as_bytes is not None :
52
+ return binary_as_bytes
53
+
54
+ from pyspark .sql import SparkSession
55
+
56
+ spark = SparkSession .getActiveSession ()
57
+ if spark is not None :
58
+ v = spark .conf .get ("spark.sql.execution.pyspark.binaryAsBytes.enabled" , "true" )
59
+ return str (v ).lower () == "true"
60
+
61
+ return True
62
+
63
+
49
64
class LocalDataToArrowConversion :
50
65
"""
51
66
Conversion from local data (except pandas DataFrame and numpy ndarray) to Arrow.
@@ -518,13 +533,16 @@ def _create_converter(dataType: DataType) -> Callable:
518
533
@overload
519
534
@staticmethod
520
535
def _create_converter (
521
- dataType : DataType , * , none_on_identity : bool = True
536
+ dataType : DataType , * , none_on_identity : bool = True , binary_as_bytes : Optional [ bool ] = None
522
537
) -> Optional [Callable ]:
523
538
pass
524
539
525
540
@staticmethod
526
541
def _create_converter (
527
- dataType : DataType , * , none_on_identity : bool = False
542
+ dataType : DataType ,
543
+ * ,
544
+ none_on_identity : bool = False ,
545
+ binary_as_bytes : Optional [bool ] = None ,
528
546
) -> Optional [Callable ]:
529
547
assert dataType is not None and isinstance (dataType , DataType )
530
548
@@ -542,7 +560,9 @@ def _create_converter(
542
560
dedup_field_names = _dedup_names (field_names )
543
561
544
562
field_convs = [
545
- ArrowTableToRowsConversion ._create_converter (f .dataType , none_on_identity = True )
563
+ ArrowTableToRowsConversion ._create_converter (
564
+ f .dataType , none_on_identity = True , binary_as_bytes = binary_as_bytes
565
+ )
546
566
for f in dataType .fields
547
567
]
548
568
@@ -564,7 +584,7 @@ def convert_struct(value: Any) -> Any:
564
584
565
585
elif isinstance (dataType , ArrayType ):
566
586
element_conv = ArrowTableToRowsConversion ._create_converter (
567
- dataType .elementType , none_on_identity = True
587
+ dataType .elementType , none_on_identity = True , binary_as_bytes = binary_as_bytes
568
588
)
569
589
570
590
if element_conv is None :
@@ -589,10 +609,10 @@ def convert_array(value: Any) -> Any:
589
609
590
610
elif isinstance (dataType , MapType ):
591
611
key_conv = ArrowTableToRowsConversion ._create_converter (
592
- dataType .keyType , none_on_identity = True
612
+ dataType .keyType , none_on_identity = True , binary_as_bytes = binary_as_bytes
593
613
)
594
614
value_conv = ArrowTableToRowsConversion ._create_converter (
595
- dataType .valueType , none_on_identity = True
615
+ dataType .valueType , none_on_identity = True , binary_as_bytes = binary_as_bytes
596
616
)
597
617
598
618
if key_conv is None :
@@ -646,7 +666,10 @@ def convert_binary(value: Any) -> Any:
646
666
return None
647
667
else :
648
668
assert isinstance (value , bytes )
649
- return bytearray (value )
669
+ if _should_use_bytes_for_binary (binary_as_bytes ):
670
+ return value
671
+ else :
672
+ return bytearray (value )
650
673
651
674
return convert_binary
652
675
@@ -676,7 +699,7 @@ def convert_timestample_ntz(value: Any) -> Any:
676
699
udt : UserDefinedType = dataType
677
700
678
701
conv = ArrowTableToRowsConversion ._create_converter (
679
- udt .sqlType (), none_on_identity = True
702
+ udt .sqlType (), none_on_identity = True , binary_as_bytes = binary_as_bytes
680
703
)
681
704
682
705
if conv is None :
@@ -722,20 +745,28 @@ def convert_variant(value: Any) -> Any:
722
745
@overload
723
746
@staticmethod
724
747
def convert ( # type: ignore[overload-overlap]
725
- table : "pa.Table" , schema : StructType
748
+ table : "pa.Table" , schema : StructType , * , binary_as_bytes : Optional [ bool ] = None
726
749
) -> List [Row ]:
727
750
pass
728
751
729
752
@overload
730
753
@staticmethod
731
754
def convert (
732
- table : "pa.Table" , schema : StructType , * , return_as_tuples : bool = True
755
+ table : "pa.Table" ,
756
+ schema : StructType ,
757
+ * ,
758
+ return_as_tuples : bool = True ,
759
+ binary_as_bytes : Optional [bool ] = None ,
733
760
) -> List [tuple ]:
734
761
pass
735
762
736
763
@staticmethod # type: ignore[misc]
737
764
def convert (
738
- table : "pa.Table" , schema : StructType , * , return_as_tuples : bool = False
765
+ table : "pa.Table" ,
766
+ schema : StructType ,
767
+ * ,
768
+ return_as_tuples : bool = False ,
769
+ binary_as_bytes : Optional [bool ] = None ,
739
770
) -> List [Union [Row , tuple ]]:
740
771
require_minimum_pyarrow_version ()
741
772
import pyarrow as pa
@@ -748,7 +779,9 @@ def convert(
748
779
749
780
if len (fields ) > 0 :
750
781
field_converters = [
751
- ArrowTableToRowsConversion ._create_converter (f .dataType , none_on_identity = True )
782
+ ArrowTableToRowsConversion ._create_converter (
783
+ f .dataType , none_on_identity = True , binary_as_bytes = binary_as_bytes
784
+ )
752
785
for f in schema .fields
753
786
]
754
787
0 commit comments