@@ -614,14 +614,21 @@ def rename_all_columns(df: DataFrame, prefix: str) -> DataFrame:
614
614
)
615
615
616
616
617
- def safe_array_union (a : Column , b : Column ) -> Column :
617
+ def safe_array_union (
618
+ a : Column , b : Column , fields_order : list [str ] | None = None
619
+ ) -> Column :
618
620
"""Merge the content of two optional columns.
619
621
620
- The function assumes the array columns have the same schema. Otherwise, the function will fail.
622
+ The function assumes the array columns have the same schema.
623
+ If the `fields_order` is passed, the function assumes that it deals with array of structs and sorts the nested
624
+ struct fields by the provided `fields_order` before conducting array_merge.
625
+ If the `fields_order` is not passed and both columns are <array<struct<...>> type then function assumes struct fields have the same order,
626
+ otherwise the function will raise an AnalysisException.
621
627
622
628
Args:
623
629
a (Column): One optional array column.
624
630
b (Column): The other optional array column.
631
+ fields_order (list[str] | None): The order of the fields in the struct. Defaults to None.
625
632
626
633
Returns:
627
634
Column: array column with merged content.
@@ -644,12 +651,89 @@ def safe_array_union(a: Column, b: Column) -> Column:
644
651
| null|
645
652
+------+
646
653
<BLANKLINE>
654
+ >>> schema="arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
655
+ >>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
656
+ >>> df = spark.createDataFrame(data=data, schema=schema)
657
+ >>> df.select(safe_array_union(f.col("arr"), f.col("arr2"), fields_order=["a", "b"]).alias("merged")).show()
658
+ +----------------+
659
+ | merged|
660
+ +----------------+
661
+ |[{a, 1}, {c, 2}]|
662
+ +----------------+
663
+ <BLANKLINE>
664
+ >>> schema="arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
665
+ >>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
666
+ >>> df = spark.createDataFrame(data=data, schema=schema)
667
+ >>> df.select(safe_array_union(f.col("arr"), f.col("arr2")).alias("merged")).show() # doctest: +IGNORE_EXCEPTION_DETAIL
668
+ Traceback (most recent call last):
669
+ pyspark.sql.utils.AnalysisException: ...
647
670
"""
671
+ if fields_order :
672
+ # sort the nested struct fields by the provided order
673
+ a = sort_array_struct_by_columns (a , fields_order )
674
+ b = sort_array_struct_by_columns (b , fields_order )
648
675
return f .when (a .isNotNull () & b .isNotNull (), f .array_union (a , b )).otherwise (
649
676
f .coalesce (a , b )
650
677
)
651
678
652
679
680
+
681
def sort_array_struct_by_columns(column: Column, fields_order: list[str]) -> Column:
    """Sort nested struct fields by provided fields order.

    Builds a Spark SQL expression that rewrites each struct element of the
    array so its fields appear in ``fields_order``, then sorts the array in
    descending order. The result is re-aliased to the original column name.

    Args:
        column (Column): Column with array of structs.
        fields_order (list[str]): List of field names to sort by.

    Returns:
        Column: Sorted column.

    Examples:
        >>> schema="arr: array<struct<b:int,a:string>>"
        >>> data = [([(1,"a",), (2, "c")],)]
        >>> fields_order = ["a", "b"]
        >>> df = spark.createDataFrame(data=data, schema=schema)
        >>> df.select(sort_array_struct_by_columns(f.col("arr"), fields_order).alias("sorted")).show()
        +----------------+
        |          sorted|
        +----------------+
        |[{c, 2}, {a, 1}]|
        +----------------+
        <BLANKLINE>
    """
    name = extract_column_name(column)
    # struct(x.a, x.b, ...) re-orders the nested fields element-wise.
    struct_args = ", ".join(f"x.{field}" for field in fields_order)
    sql = f"sort_array(transform({name}, x -> struct({struct_args})), False)"
    return f.expr(sql).alias(name)
709
+
710
+
711
def extract_column_name(column: Column) -> str:
    """Extract column name from a column expression.

    Relies on the textual representation of a pyspark ``Column``, which has
    the form ``Column<'name'>``.

    Args:
        column (Column): Column expression.

    Returns:
        str: Column name.

    Raises:
        ValueError: If the column name cannot be extracted.

    Examples:
        >>> extract_column_name(f.col('col1'))
        'col1'
        >>> extract_column_name(f.sort_array(f.col('col1')))
        'sort_array(col1, true)'
    """
    # fullmatch, rather than a loosely anchored search with an optional
    # trailing ">", rejects malformed representations (e.g. a missing
    # closing ">") instead of silently returning a truncated name.
    _match = re.fullmatch(r"Column<'(?P<name>.*)'>", str(column))
    if not _match:
        raise ValueError(f"Cannot extract column name from {column}")
    return _match.group("name")
735
+
736
+
653
737
def create_empty_column_if_not_exists (
654
738
col_name : str , col_schema : t .DataType = t .NullType ()
655
739
) -> Column :
0 commit comments