Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(vep_parser): use nested schema for insilico predictors #789

Merged
merged 4 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions src/gentropy/common/spark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ def order_array_of_structs_by_two_fields(
"""Sort array of structs by a field in descending order and by an other field in an ascending order.

This function doesn't deal with null values, assumes the sort columns are not nullable.
The sorting function compares the descending_column first, in case when two values from descending_column are equal
it compares the ascending_column. When values in both columns are equal, the rows order is preserved.

Args:
array_name (str): Column name with array of structs
Expand All @@ -406,6 +408,20 @@ def order_array_of_structs_by_two_fields(
|[{1.0, 45, First}, {1.0, 125, Second}, {0.5, 232, Third}, {0.5, 233, Fourth}]|
+-----------------------------------------------------------------------------+
<BLANKLINE>
>>> data = [(1.0, 45, 'First'), (1.0, 45, 'Second'), (0.5, 233, 'Fourth'), (1.0, 125, 'Third'),]
>>> (
... spark.createDataFrame(data, ['col1', 'col2', 'ranking'])
... .groupBy(f.lit('c'))
... .agg(f.collect_list(f.struct('col1','col2', 'ranking')).alias('list'))
... .select(order_array_of_structs_by_two_fields('list', 'col1', 'col2').alias('sorted_list'))
... .show(truncate=False)
... )
+----------------------------------------------------------------------------+
|sorted_list |
+----------------------------------------------------------------------------+
|[{1.0, 45, First}, {1.0, 45, Second}, {1.0, 125, Third}, {0.5, 233, Fourth}]|
+----------------------------------------------------------------------------+
<BLANKLINE>
"""
return f.expr(
f"""
Expand All @@ -425,6 +441,7 @@ def order_array_of_structs_by_two_fields(
when left.{descending_column} > right.{descending_column} then -1
when left.{descending_column} == right.{descending_column} and left.{ascending_column} > right.{ascending_column} then 1
when left.{descending_column} == right.{descending_column} and left.{ascending_column} < right.{ascending_column} then -1
when left.{ascending_column} == right.{ascending_column} and left.{descending_column} == right.{descending_column} then 0
end)
"""
)
Expand Down Expand Up @@ -525,7 +542,7 @@ def get_value_from_row(row: Row, column: str) -> Any:


def enforce_schema(
expected_schema: t.StructType,
expected_schema: t.ArrayType | t.StructType | Column | str,
project-defiant marked this conversation as resolved.
Show resolved Hide resolved
) -> Callable[..., Any]:
"""A function to enforce the schema of a function output follows expectation.

Expand All @@ -541,7 +558,7 @@ def my_function() -> t.StructType:
return ...

Args:
expected_schema (t.StructType): The expected schema of the output.
expected_schema (t.ArrayType | t.StructType | Column | str): The expected schema of the output.

Returns:
Callable[..., Any]: A decorator function.
Expand Down Expand Up @@ -687,3 +704,34 @@ def get_standard_error_from_confidence_interval(lower: Column, upper: Column) ->
<BLANKLINE>
"""
return (upper - lower) / (2 * 1.96)


def get_nested_struct_schema(dtype: t.DataType) -> t.StructType:
"""Get the bottom StructType from a nested ArrayType type.

Args:
dtype (t.DataType): The nested data structure.

Returns:
t.StructType: The nested struct schema.

Raises:
TypeError: If the input data type is not a nested struct.

Examples:
>>> get_nested_struct_schema(t.ArrayType(t.StructType([t.StructField('a', t.StringType())])))
StructType([StructField('a', StringType(), True)])

>>> get_nested_struct_schema(t.ArrayType(t.ArrayType(t.StructType([t.StructField("a", t.StringType())]))))
StructType([StructField('a', StringType(), True)])
"""
if isinstance(dtype, t.StructField):
dtype = dtype.dataType

match dtype:
case t.StructType(fields=_):
return dtype
case t.ArrayType(elementType=dtype):
return get_nested_struct_schema(dtype)
case _:
raise TypeError("The input data type must be a nested struct.")
23 changes: 13 additions & 10 deletions src/gentropy/datasource/ensembl/vep_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from gentropy.common.schemas import parse_spark_schema
from gentropy.common.spark_helpers import (
enforce_schema,
get_nested_struct_schema,
map_column_by_dictionary,
order_array_of_structs_by_field,
order_array_of_structs_by_two_fields,
Expand All @@ -26,14 +27,16 @@

class VariantEffectPredictorParser:
"""Collection of methods to parse VEP output in json format."""
# NOTE: Due to the fact that the comparison of the xrefs is done om the base of rsids
# if the field `colocalised_variants` have multiple rsids, this extracting xrefs will result in
# an array of xref structs, rather then the struct itself.

# Schema description of the dbXref object:
DBXREF_SCHEMA = VariantIndex.get_schema()["dbXrefs"].dataType

# Schema description of the in silico predictor object:
IN_SILICO_PREDICTOR_SCHEMA = VariantIndex.get_schema()[
"inSilicoPredictors"
].dataType
IN_SILICO_PREDICTOR_SCHEMA = get_nested_struct_schema(
VariantIndex.get_schema()["inSilicoPredictors"]
)

# Schema for the allele frequency column:
ALLELE_FREQUENCY_SCHEMA = VariantIndex.get_schema()["alleleFrequencies"].dataType
Expand Down Expand Up @@ -350,12 +353,12 @@ def _get_max_alpha_missense(transcripts: Column) -> Column:
... .select(VariantEffectPredictorParser._get_max_alpha_missense(f.col('transcripts')).alias('am'))
... .show(truncate=False)
... )
+------------------------------------------------------+
|am |
+------------------------------------------------------+
|[{max alpha missense, assessment 1, 0.4, null, gene1}]|
|[{max alpha missense, null, null, null, gene1}] |
+------------------------------------------------------+
+----------------------------------------------------+
|am |
+----------------------------------------------------+
|{max alpha missense, assessment 1, 0.4, null, gene1}|
|{max alpha missense, null, null, null, gene1} |
+----------------------------------------------------+
<BLANKLINE>
"""
return f.transform(
Expand Down
Loading
Loading