Skip to content

Commit a135d26

Browse files
project-defiant (Szymon Szyszkowski) and co-authors
authored
fix(safe_array_union): allow for sorting nested structs (#793)
* fix: remove study_index_path from coloc step * fix(safe_array_union): sort struct fields in array --------- Co-authored-by: Szymon Szyszkowski <ss60@mib117351s.internal.sanger.ac.uk>
1 parent 9f83329 commit a135d26

File tree

3 files changed

+98
-5
lines changed

3 files changed

+98
-5
lines changed

src/gentropy/common/spark_helpers.py

Lines changed: 86 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -614,14 +614,21 @@ def rename_all_columns(df: DataFrame, prefix: str) -> DataFrame:
614614
)
615615

616616

617-
def safe_array_union(a: Column, b: Column) -> Column:
617+
def safe_array_union(
618+
a: Column, b: Column, fields_order: list[str] | None = None
619+
) -> Column:
618620
"""Merge the content of two optional columns.
619621
620-
The function assumes the array columns have the same schema. Otherwise, the function will fail.
622+
The function assumes the array columns have the same schema.
623+
If the `fields_order` is passed, the function assumes that it deals with array of structs and sorts the nested
624+
struct fields by the provided `fields_order` before conducting array_merge.
625+
If the `fields_order` is not passed and both columns are <array<struct<...>> type then function assumes struct fields have the same order,
626+
otherwise the function will raise an AnalysisException.
621627
622628
Args:
623629
a (Column): One optional array column.
624630
b (Column): The other optional array column.
631+
fields_order (list[str] | None): The order of the fields in the struct. Defaults to None.
625632
626633
Returns:
627634
Column: array column with merged content.
@@ -644,12 +651,89 @@ def safe_array_union(a: Column, b: Column) -> Column:
644651
| null|
645652
+------+
646653
<BLANKLINE>
654+
>>> schema="arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
655+
>>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
656+
>>> df = spark.createDataFrame(data=data, schema=schema)
657+
>>> df.select(safe_array_union(f.col("arr"), f.col("arr2"), fields_order=["a", "b"]).alias("merged")).show()
658+
+----------------+
659+
| merged|
660+
+----------------+
661+
|[{a, 1}, {c, 2}]|
662+
+----------------+
663+
<BLANKLINE>
664+
>>> schema="arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
665+
>>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
666+
>>> df = spark.createDataFrame(data=data, schema=schema)
667+
>>> df.select(safe_array_union(f.col("arr"), f.col("arr2")).alias("merged")).show() # doctest: +IGNORE_EXCEPTION_DETAIL
668+
Traceback (most recent call last):
669+
pyspark.sql.utils.AnalysisException: ...
647670
"""
671+
if fields_order:
672+
# sort the nested struct fields by the provided order
673+
a = sort_array_struct_by_columns(a, fields_order)
674+
b = sort_array_struct_by_columns(b, fields_order)
648675
return f.when(a.isNotNull() & b.isNotNull(), f.array_union(a, b)).otherwise(
649676
f.coalesce(a, b)
650677
)
651678

652679

680+
681+
def sort_array_struct_by_columns(column: Column, fields_order: list[str]) -> Column:
    """Normalise the field order of structs nested inside an array column.

    Rewrites every struct element so its fields appear in `fields_order`,
    then sorts the resulting array (descending). The output column keeps
    the original column's name.

    Args:
        column (Column): Column with array of structs.
        fields_order (list[str]): List of field names to sort by.

    Returns:
        Column: Sorted column.

    Examples:
        >>> schema="arr: array<struct<b:int,a:string>>"
        >>> data = [([(1,"a",), (2, "c")],)]
        >>> fields_order = ["a", "b"]
        >>> df = spark.createDataFrame(data=data, schema=schema)
        >>> df.select(sort_array_struct_by_columns(f.col("arr"), fields_order).alias("sorted")).show()
        +----------------+
        |          sorted|
        +----------------+
        |[{c, 2}, {a, 1}]|
        +----------------+
        <BLANKLINE>
    """
    # The SQL expression needs the column's textual name, not the Column object.
    name = extract_column_name(column)
    struct_fields = ", ".join(f"x.{field}" for field in fields_order)
    # transform(...) rebuilds each struct with fields in the requested order;
    # sort_array(..., False) then orders the array itself descending.
    sort_expr = f"sort_array(transform({name}, x -> struct({struct_fields})), False)"
    return f.expr(sort_expr).alias(name)
709+
710+
711+
def extract_column_name(column: Column) -> str:
    """Extract column name from a column expression.

    Relies on the `Column<'...'>` textual representation that pyspark gives
    to columns, so it works for plain columns as well as derived expressions.

    Args:
        column (Column): Column expression.

    Returns:
        str: Column name.

    Raises:
        ValueError: If the column name cannot be extracted.

    Examples:
        >>> extract_column_name(f.col('col1'))
        'col1'
        >>> extract_column_name(f.sort_array(f.col('col1')))
        'sort_array(col1, true)'
    """
    # Raw string and `$` anchor: the previous pattern ("^Column<'(?P<name>.*)'>?")
    # made the closing `>` optional and left the match unanchored, so a
    # malformed representation such as "Column<'x'" would still "match".
    pattern = re.compile(r"^Column<'(?P<name>.*)'>$")

    _match = pattern.search(str(column))
    if not _match:
        raise ValueError(f"Cannot extract column name from {column}")
    return _match.group("name")
735+
736+
653737
def create_empty_column_if_not_exists(
654738
col_name: str, col_schema: t.DataType = t.NullType()
655739
) -> Column:

src/gentropy/dataset/variant_index.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
from typing import TYPE_CHECKING
77

88
import pyspark.sql.functions as f
9+
import pyspark.sql.types as t
910

1011
from gentropy.common.schemas import parse_spark_schema
1112
from gentropy.common.spark_helpers import (
13+
get_nested_struct_schema,
1214
get_record_with_maximum_value,
1315
normalise_column,
1416
rename_all_columns,
@@ -131,6 +133,7 @@ def add_annotation(
131133
# Prefix for renaming columns:
132134
prefix = "annotation_"
133135

136+
134137
# Generate select expressions that to merge and import columns from annotation:
135138
select_expressions = []
136139

@@ -141,10 +144,17 @@ def add_annotation(
141144
# If an annotation column can be found in both datasets:
142145
if (column in self.df.columns) and (column in annotation_source.df.columns):
143146
# Arrays are merged:
144-
if "ArrayType" in field.dataType.__str__():
147+
if isinstance(field.dataType, t.ArrayType):
148+
fields_order = None
149+
if isinstance(field.dataType.elementType, t.StructType):
150+
# Extract the schema of the array to get the order of the fields:
151+
array_schema = [
152+
field for field in VariantIndex.get_schema().fields if field.name == column
153+
][0].dataType
154+
fields_order = get_nested_struct_schema(array_schema).fieldNames()
145155
select_expressions.append(
146156
safe_array_union(
147-
f.col(column), f.col(f"{prefix}{column}")
157+
f.col(column), f.col(f"{prefix}{column}"), fields_order
148158
).alias(column)
149159
)
150160
# Non-array columns are coalesced:

src/gentropy/datasource/open_targets/variants.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def as_vcf_df(
9595
variant_df = variant_df.withColumn(
9696
col, create_empty_column_if_not_exists(col)
9797
)
98-
9998
return (
10099
variant_df.filter(f.col("variantId").isNotNull())
101100
.withColumn(

0 commit comments

Comments
 (0)