Skip to content

Commit

Permalink
Merge branch 'dev' into ds_3545-schema-validation-misses-nested-arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
DSuveges authored Sep 27, 2024
2 parents 6c6918f + a135d26 commit 77eb0b8
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 5 deletions.
88 changes: 86 additions & 2 deletions src/gentropy/common/spark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,14 +614,21 @@ def rename_all_columns(df: DataFrame, prefix: str) -> DataFrame:
)


def safe_array_union(a: Column, b: Column) -> Column:
def safe_array_union(
a: Column, b: Column, fields_order: list[str] | None = None
) -> Column:
"""Merge the content of two optional columns.
The function assumes the array columns have the same schema. Otherwise, the function will fail.
The function assumes the array columns have the same schema.
If `fields_order` is passed, the function assumes that it deals with arrays of structs and reorders the nested
struct fields according to the provided `fields_order` before applying `array_union`.
If `fields_order` is not passed and both columns are of type array<struct<...>>, the function assumes the struct fields are in the same order;
otherwise an AnalysisException is raised.
Args:
a (Column): One optional array column.
b (Column): The other optional array column.
fields_order (list[str] | None): The order of the fields in the struct. Defaults to None.
Returns:
Column: array column with merged content.
Expand All @@ -644,12 +651,89 @@ def safe_array_union(a: Column, b: Column) -> Column:
| null|
+------+
<BLANKLINE>
>>> schema="arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
>>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
>>> df = spark.createDataFrame(data=data, schema=schema)
>>> df.select(safe_array_union(f.col("arr"), f.col("arr2"), fields_order=["a", "b"]).alias("merged")).show()
+----------------+
| merged|
+----------------+
|[{a, 1}, {c, 2}]|
+----------------+
<BLANKLINE>
>>> schema="arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
>>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
>>> df = spark.createDataFrame(data=data, schema=schema)
>>> df.select(safe_array_union(f.col("arr"), f.col("arr2")).alias("merged")).show() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
pyspark.sql.utils.AnalysisException: ...
"""
if fields_order:
# sort the nested struct fields by the provided order
a = sort_array_struct_by_columns(a, fields_order)
b = sort_array_struct_by_columns(b, fields_order)
return f.when(a.isNotNull() & b.isNotNull(), f.array_union(a, b)).otherwise(
f.coalesce(a, b)
)



def sort_array_struct_by_columns(column: Column, fields_order: list[str]) -> Column:
    """Reorder the fields of every struct in an array column and sort the array.

    Each struct element is rebuilt with its fields in the order given by `fields_order`, then the
    resulting array is sorted in descending order. The returned column keeps the name of the input column.

    Args:
        column (Column): Column with array of structs.
        fields_order (list[str]): Field names in the order they should appear in the rebuilt structs.

    Returns:
        Column: Array column with reordered struct fields, sorted in descending order.

    Examples:
        >>> schema="arr: array<struct<b:int,a:string>>"
        >>> data = [([(1,"a",), (2, "c")],)]
        >>> fields_order = ["a", "b"]
        >>> df = spark.createDataFrame(data=data, schema=schema)
        >>> df.select(sort_array_struct_by_columns(f.col("arr"), fields_order).alias("sorted")).show()
        +----------------+
        |          sorted|
        +----------------+
        |[{c, 2}, {a, 1}]|
        +----------------+
        <BLANKLINE>
    """
    # The SQL expression refers to the column by the name recovered from its representation.
    array_name = extract_column_name(column)
    # `x` is the lambda variable bound to each struct element inside `transform`.
    struct_fields = ", ".join(f"x.{name}" for name in fields_order)
    sql_expression = (
        f"sort_array(transform({array_name}, x -> struct({struct_fields})), False)"
    )
    return f.expr(sql_expression).alias(array_name)


def extract_column_name(column: Column) -> str:
    """Extract column name from a column expression.

    The name is recovered from the string representation of the column, which is expected to
    have exactly the form ``Column<'<name>'>``. For derived columns the "name" is the full
    expression text.

    Args:
        column (Column): Column expression.

    Returns:
        str: Column name.

    Raises:
        ValueError: If the column name cannot be extracted.

    Examples:
        >>> extract_column_name(f.col('col1'))
        'col1'
        >>> extract_column_name(f.sort_array(f.col('col1')))
        'sort_array(col1, true)'
    """
    # The whole representation has to match: a truncated or otherwise malformed representation
    # must raise instead of silently yielding a wrong name. DOTALL lets expressions that contain
    # line breaks be extracted as well.
    _match = re.fullmatch(r"Column<'(?P<name>.*)'>", str(column), flags=re.DOTALL)
    if _match is None:
        raise ValueError(f"Cannot extract column name from {column}")
    return _match.group("name")


def create_empty_column_if_not_exists(
col_name: str, col_schema: t.DataType = t.NullType()
) -> Column:
Expand Down
14 changes: 12 additions & 2 deletions src/gentropy/dataset/variant_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from typing import TYPE_CHECKING

import pyspark.sql.functions as f
import pyspark.sql.types as t

from gentropy.common.schemas import parse_spark_schema
from gentropy.common.spark_helpers import (
get_nested_struct_schema,
get_record_with_maximum_value,
normalise_column,
rename_all_columns,
Expand Down Expand Up @@ -131,6 +133,7 @@ def add_annotation(
# Prefix for renaming columns:
prefix = "annotation_"


# Generate select expressions to merge and import columns from annotation:
select_expressions = []

Expand All @@ -141,10 +144,17 @@ def add_annotation(
# If an annotation column can be found in both datasets:
if (column in self.df.columns) and (column in annotation_source.df.columns):
# Arrays are merged:
if "ArrayType" in field.dataType.__str__():
if isinstance(field.dataType, t.ArrayType):
fields_order = None
if isinstance(field.dataType.elementType, t.StructType):
# Extract the schema of the array to get the order of the fields:
array_schema = [
field for field in VariantIndex.get_schema().fields if field.name == column
][0].dataType
fields_order = get_nested_struct_schema(array_schema).fieldNames()
select_expressions.append(
safe_array_union(
f.col(column), f.col(f"{prefix}{column}")
f.col(column), f.col(f"{prefix}{column}"), fields_order
).alias(column)
)
# Non-array columns are coalesced:
Expand Down
1 change: 0 additions & 1 deletion src/gentropy/datasource/open_targets/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def as_vcf_df(
variant_df = variant_df.withColumn(
col, create_empty_column_if_not_exists(col)
)

return (
variant_df.filter(f.col("variantId").isNotNull())
.withColumn(
Expand Down

0 comments on commit 77eb0b8

Please sign in to comment.