
Commit

Merge branch 'dev' into vh-3448
DSuveges authored Sep 30, 2024
2 parents caea96e + 8b253a5 commit bd0ed41
Showing 13 changed files with 659 additions and 125 deletions.
224 changes: 185 additions & 39 deletions src/gentropy/common/schemas.py
@@ -1,66 +1,212 @@
"""Methods for handling schemas."""

from __future__ import annotations

import importlib.resources as pkg_resources
import json
-from collections import namedtuple
-from typing import Any
from collections import defaultdict

-import pyspark.sql.types as t
from pyspark.sql.types import ArrayType, StructType

from gentropy.assets import schemas


-def parse_spark_schema(schema_json: str) -> t.StructType:
class SchemaValidationError(Exception):
"""This exception is raised when a schema validation fails."""

def __init__(
self: SchemaValidationError, message: str, errors: defaultdict[str, list[str]]
) -> None:
"""Initialize the SchemaValidationError.
Args:
message (str): The message to be displayed.
errors (defaultdict[str, list[str]]): The collection of observed discrepancies
"""
super().__init__(message)
self.message = message # Explicitly set the message attribute
self.errors = errors

def __str__(self: SchemaValidationError) -> str:
"""Return a string representation of the exception.
Returns:
str: The string representation of the exception.
"""
stringified_errors = "\n ".join(
[f'{k}: {",".join(v)}' for k, v in self.errors.items()]
)
return f"{self.message}\nErrors:\n {stringified_errors}"


def parse_spark_schema(schema_json: str) -> StructType:
"""Parse Spark schema from JSON.
Args:
schema_json (str): JSON filename containing spark schema in the schemas package
Returns:
-        t.StructType: Spark schema
StructType: Spark schema
"""
core_schema = json.loads(
pkg_resources.read_text(schemas, schema_json, encoding="utf-8")
)
-    return t.StructType.fromJson(core_schema)
return StructType.fromJson(core_schema)
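As a usage sketch (not from the commit; the asset file name is only an example), the function is typically called with the name of a JSON schema bundled in the gentropy.assets.schemas package:

from gentropy.common.schemas import parse_spark_schema

# "study_locus.json" is an illustrative asset name; any schema JSON shipped in
# gentropy.assets.schemas can be loaded the same way.
study_locus_schema = parse_spark_schema("study_locus.json")
print(study_locus_schema.fieldNames())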


-def flatten_schema(schema: t.StructType, prefix: str = "") -> list[Any]:
-    """It takes a Spark schema and returns a list of all fields in the schema once flattened.
def compare_array_schemas(
observed_schema: ArrayType,
expected_schema: ArrayType,
parent_field_name: str | None = None,
schema_issues: defaultdict[str, list[str]] | None = None,
) -> defaultdict[str, list[str]]:
"""Compare two array schemas.
The comparison is done recursively, so nested structs are also compared.
Args:
-        schema (t.StructType): The schema of the dataframe
-        prefix (str): The prefix to prepend to the field names. Defaults to "".
observed_schema (ArrayType): The observed schema.
expected_schema (ArrayType): The expected schema.
parent_field_name (str | None): The parent field name. Defaults to None.
schema_issues (defaultdict[str, list[str]] | None): The schema issues. Defaults to None.
Returns:
-        list[Any]: A list of all the columns in the dataframe.
-    Examples:
-        >>> from pyspark.sql.types import ArrayType, StringType, StructField, StructType
-        >>> schema = StructType(
-        ...     [
-        ...         StructField("studyLocusId", StringType(), False),
-        ...         StructField("locus", ArrayType(StructType([StructField("variantId", StringType(), False)])), False)
-        ...     ]
-        ... )
-        >>> df = spark.createDataFrame([("A", [{"variantId": "varA"}]), ("B", [{"variantId": "varB"}])], schema)
-        >>> flatten_schema(df.schema)
-        [Field(name='studyLocusId', dataType=StringType()), Field(name='locus', dataType=ArrayType(StructType([]), True)), Field(name='locus.variantId', dataType=StringType())]
defaultdict[str, list[str]]: The schema issues.
"""
-    Field = namedtuple("Field", ["name", "dataType"])
-    fields = []
-    for field in schema.fields:
-        name = f"{prefix}.{field.name}" if prefix else field.name
-        dtype = field.dataType
-        if isinstance(dtype, t.StructType):
-            fields.append(Field(name, t.ArrayType(t.StructType())))
-            fields += flatten_schema(dtype, prefix=name)
-        elif isinstance(dtype, t.ArrayType) and isinstance(
-            dtype.elementType, t.StructType
-        ):
-            fields.append(Field(name, t.ArrayType(t.StructType())))
-            fields += flatten_schema(dtype.elementType, prefix=name)
-        else:
-            fields.append(Field(name, dtype))
-    return fields
# Create default values if not provided:
if schema_issues is None:
schema_issues = defaultdict(list)

if parent_field_name is None:
parent_field_name = ""

observed_type = observed_schema.elementType.typeName()
expected_type = expected_schema.elementType.typeName()

    # If the element types do not match, no further tests are needed:
if observed_type != expected_type:
schema_issues["columns_with_non_matching_type"].append(
f'For column "{parent_field_name}[]" found {observed_type} instead of {expected_type}'
)

# If element type is a struct, resolve nesting:
elif (observed_type == "struct") and (expected_type == "struct"):
schema_issues = compare_struct_schemas(
observed_schema.elementType,
expected_schema.elementType,
f"{parent_field_name}[].",
schema_issues,
)

# If element type is an array, resolve nesting:
elif (observed_type == "array") and (expected_type == "array"):
schema_issues = compare_array_schemas(
observed_schema.elementType,
expected_schema.elementType,
parent_field_name,
schema_issues,
)

return schema_issues
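A minimal illustration (not part of the diff; the column names are invented) of how a type mismatch inside an array of structs is reported:

from pyspark.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

from gentropy.common.schemas import compare_array_schemas

observed = ArrayType(StructType([StructField("variantId", IntegerType())]))
expected = ArrayType(StructType([StructField("variantId", StringType())]))
compare_array_schemas(observed, expected, parent_field_name="locus")
# defaultdict(<class 'list'>, {'columns_with_non_matching_type':
#     ['For column "locus[].variantId" found integer instead of string']})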


def compare_struct_schemas(
observed_schema: StructType,
expected_schema: StructType,
parent_field_name: str | None = None,
schema_issues: defaultdict[str, list[str]] | None = None,
) -> defaultdict[str, list[str]]:
"""Compare two struct schemas.
The comparison is done recursively, so nested structs are also compared.
Checking logic:
1. Checking for duplicated columns in the observed schema.
2. Checking for missing mandatory columns in the observed schema.
        3. Once all mandatory columns are known to be present, iterate over the observed schema and compare the types.
4. Flagging unexpected columns in the observed schema.
5. Flagging columns with non-matching types.
6. If a column is a struct -> call compare_struct_schemas
7. If a column is an array -> call compare_array_schemas
8. Return dictionary with issues.
Args:
observed_schema (StructType): The observed schema.
expected_schema (StructType): The expected schema.
parent_field_name (str | None): The parent field name. Defaults to None.
schema_issues (defaultdict[str, list[str]] | None): The schema issues. Defaults to None.
Returns:
defaultdict[str, list[str]]: The schema issues.
"""
# Create default values if not provided:
if schema_issues is None:
schema_issues = defaultdict(list)

if parent_field_name is None:
parent_field_name = ""

# Flagging duplicated columns if present:
if duplicated_columns := list(
{
f"{parent_field_name}{field.name}"
for field in observed_schema
if list(observed_schema).count(field) > 1
}
):
schema_issues["duplicated_columns"] += duplicated_columns

# Testing mandatory fields:
required_fields = [x.name for x in expected_schema if not x.nullable]
if missing_required_fields := [
f"{parent_field_name}{req}"
for req in required_fields
if not any(field.name == req for field in observed_schema)
]:
schema_issues["missing_mandatory_columns"] += missing_required_fields

# Converting schema to dictionaries for easier comparison:
observed_schema_dict = {field.name: field for field in observed_schema}
expected_schema_dict = {field.name: field for field in expected_schema}

# Testing optional fields and types:
for field_name, field in observed_schema_dict.items():
        # Test the observed field name; if the name is not matched, no further tests are needed:
if field_name not in expected_schema_dict:
schema_issues["unexpected_columns"].append(
f"{parent_field_name}{field_name}"
)
continue

        # Now that the field is known to be present in both schemas, extract the field type information:
observed_type = field.dataType
observed_type_name = field.dataType.typeName()

expected_type = expected_schema_dict[field_name].dataType
expected_type_name = expected_schema_dict[field_name].dataType.typeName()

        # Flag the discrepancy if the types do not match, then jump to the next field:
if observed_type_name != expected_type_name:
schema_issues["columns_with_non_matching_type"].append(
f'For column "{parent_field_name}{field_name}" found {observed_type_name} instead of {expected_type_name}'
)
continue

# If column is a struct, resolve nesting:
if observed_type_name == "struct":
schema_issues = compare_struct_schemas(
observed_type,
expected_type,
f"{parent_field_name}{field_name}.",
schema_issues,
)
# If column is an array, resolve nesting:
elif observed_type_name == "array":
schema_issues = compare_array_schemas(
observed_type,
expected_type,
f"{parent_field_name}{field_name}[]",
schema_issues,
)

return schema_issues
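Putting the pieces together, a sketch (illustrative only; the schema layout is invented) of how a caller might combine the comparison helper with the exception defined above:

from pyspark.sql.types import ArrayType, StringType, StructField, StructType

from gentropy.common.schemas import SchemaValidationError, compare_struct_schemas

expected = StructType([
    StructField("studyLocusId", StringType(), nullable=False),
    StructField("locus", ArrayType(StructType([StructField("variantId", StringType())]))),
])
observed = StructType([
    StructField("study_locus_id", StringType(), nullable=False),
    StructField("locus", ArrayType(StructType([StructField("variantId", StringType())]))),
])

issues = compare_struct_schemas(observed, expected)
if issues:
    raise SchemaValidationError("Observed schema does not match the expected schema", issues)
# SchemaValidationError: Observed schema does not match the expected schema
# Errors:
#   missing_mandatory_columns: studyLocusId
#   unexpected_columns: study_locus_id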
93 changes: 88 additions & 5 deletions src/gentropy/common/spark_helpers.py
@@ -270,13 +270,13 @@ def neglog_pvalue_to_mantissa_and_exponent(p_value: Column) -> tuple[Column, Column]:
+--------+--------------+--------------+
|negLogPv|pValueMantissa|pValueExponent|
+--------+--------------+--------------+
-|    4.56|     3.6307805|            -5|
-| 2109.23|     1.6982436|         -2110|
|    4.56|     2.7542286|            -5|
| 2109.23|     5.8884363|         -2110|
+--------+--------------+--------------+
<BLANKLINE>
"""
exponent: Column = f.ceil(p_value)
-    mantissa: Column = f.pow(f.lit(10), (p_value - exponent + f.lit(1)))
mantissa: Column = f.pow(f.lit(10), (exponent - p_value))
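    # Illustrative derivation (not part of the commit): for n = -log10(p), p = 10**(-n)
    # can be rewritten as 10**(ceil(n) - n) * 10**(-ceil(n)), so the corrected mantissa
    # 10**(ceil(n) - n) always lies in [1, 10) and the reported pValueExponent is -ceil(n).
    # For n = 4.56: mantissa = 10**0.44 ≈ 2.7542286 and exponent = -5, matching the doctest above.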

return (
mantissa.cast(t.FloatType()).alias("pValueMantissa"),
@@ -614,14 +614,21 @@ def rename_all_columns(df: DataFrame, prefix: str) -> DataFrame:
)


def safe_array_union(a: Column, b: Column) -> Column:
def safe_array_union(
a: Column, b: Column, fields_order: list[str] | None = None
) -> Column:
"""Merge the content of two optional columns.
-    The function assumes the array columns have the same schema. Otherwise, the function will fail.
    The function assumes the array columns have the same schema.
    If the `fields_order` is passed, the function assumes it deals with an array of structs and sorts the nested
    struct fields by the provided `fields_order` before performing the array union.
    If the `fields_order` is not passed and both columns are of array<struct<...>> type, the function assumes
    the struct fields have the same order; otherwise it raises an AnalysisException.
Args:
a (Column): One optional array column.
b (Column): The other optional array column.
fields_order (list[str] | None): The order of the fields in the struct. Defaults to None.
Returns:
Column: array column with merged content.
@@ -644,12 +651,88 @@ def safe_array_union(a: Column, b: Column) -> Column:
|  null|
+------+
<BLANKLINE>
>>> schema="arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
>>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
>>> df = spark.createDataFrame(data=data, schema=schema)
>>> df.select(safe_array_union(f.col("arr"), f.col("arr2"), fields_order=["a", "b"]).alias("merged")).show()
+----------------+
|          merged|
+----------------+
|[{a, 1}, {c, 2}]|
+----------------+
<BLANKLINE>
>>> schema="arr2: array<struct<b:int,a:string>>, arr: array<struct<a:string,b:int>>"
>>> data = [([(1,"a",), (2, "c")],[("a", 1,)]),]
>>> df = spark.createDataFrame(data=data, schema=schema)
>>> df.select(safe_array_union(f.col("arr"), f.col("arr2")).alias("merged")).show() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
pyspark.sql.utils.AnalysisException: ...
"""
if fields_order:
# sort the nested struct fields by the provided order
a = sort_array_struct_by_columns(a, fields_order)
b = sort_array_struct_by_columns(b, fields_order)
return f.when(a.isNotNull() & b.isNotNull(), f.array_union(a, b)).otherwise(
f.coalesce(a, b)
)


def sort_array_struct_by_columns(column: Column, fields_order: list[str]) -> Column:
"""Sort nested struct fields by provided fields order.
Args:
column (Column): Column with array of structs.
fields_order (list[str]): List of field names to sort by.
Returns:
Column: Sorted column.
Examples:
>>> schema="arr: array<struct<b:int,a:string>>"
>>> data = [([(1,"a",), (2, "c")],)]
>>> fields_order = ["a", "b"]
>>> df = spark.createDataFrame(data=data, schema=schema)
>>> df.select(sort_array_struct_by_columns(f.col("arr"), fields_order).alias("sorted")).show()
+----------------+
|          sorted|
+----------------+
|[{c, 2}, {a, 1}]|
+----------------+
<BLANKLINE>
"""
column_name = extract_column_name(column)
fields_order_expr = ", ".join([f"x.{field}" for field in fields_order])
return f.expr(
f"sort_array(transform({column_name}, x -> struct({fields_order_expr})), False)"
).alias(column_name)
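For orientation (not part of the commit): with a column named arr and fields_order = ["a", "b"], the helper builds the SQL expression below. The False flag passed to sort_array sorts the transformed array in descending order by the leading struct field, which is why the doctest output above starts with {c, 2}:

# Expression generated by sort_array_struct_by_columns(f.col("arr"), ["a", "b"]):
#   sort_array(transform(arr, x -> struct(x.a, x.b)), False)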


def extract_column_name(column: Column) -> str:
"""Extract column name from a column expression.
Args:
column (Column): Column expression.
Returns:
str: Column name.
Raises:
ValueError: If the column name cannot be extracted.
Examples:
>>> extract_column_name(f.col('col1'))
'col1'
>>> extract_column_name(f.sort_array(f.col('col1')))
'sort_array(col1, true)'
"""
pattern = re.compile("^Column<'(?P<name>.*)'>?")

_match = pattern.search(str(column))
if not _match:
raise ValueError(f"Cannot extract column name from {column}")
return _match.group("name")


def create_empty_column_if_not_exists(
col_name: str, col_schema: t.DataType = t.NullType()
) -> Column:
