Commit 4d97e66
Merge branch 'dev' into yt_add_fillters_to_susie
addramir authored Sep 30, 2024
2 parents 02f8272 + 5c58e58 commit 4d97e66
Showing 36 changed files with 736 additions and 300 deletions.
4 changes: 2 additions & 2 deletions src/gentropy/assets/schemas/colocalisation.json
```diff
@@ -4,13 +4,13 @@
     {
       "name": "leftStudyLocusId",
       "nullable": false,
-      "type": "long",
+      "type": "string",
      "metadata": {}
     },
     {
       "name": "rightStudyLocusId",
       "nullable": false,
-      "type": "long",
+      "type": "string",
      "metadata": {}
     },
     {
```
2 changes: 1 addition & 1 deletion src/gentropy/assets/schemas/l2g_feature.json
```diff
@@ -3,7 +3,7 @@
   "fields": [
     {
       "name": "studyLocusId",
-      "type": "long",
+      "type": "string",
       "nullable": false,
       "metadata": {}
     },
```
2 changes: 1 addition & 1 deletion src/gentropy/assets/schemas/l2g_gold_standard.json
```diff
@@ -3,7 +3,7 @@
   "fields": [
     {
       "name": "studyLocusId",
-      "type": "long",
+      "type": "string",
       "nullable": false,
       "metadata": {}
     },
```
2 changes: 1 addition & 1 deletion src/gentropy/assets/schemas/l2g_predictions.json
```diff
@@ -3,7 +3,7 @@
   "fields": [
     {
       "name": "studyLocusId",
-      "type": "long",
+      "type": "string",
       "nullable": false,
       "metadata": {}
     },
```
2 changes: 1 addition & 1 deletion src/gentropy/assets/schemas/study_locus.json
```diff
@@ -4,7 +4,7 @@
       "metadata": {},
       "name": "studyLocusId",
       "nullable": false,
-      "type": "long"
+      "type": "string"
     },
     {
       "metadata": {},
```
4 changes: 2 additions & 2 deletions src/gentropy/assets/schemas/study_locus_overlap.json
```diff
@@ -4,13 +4,13 @@
       "metadata": {},
       "name": "leftStudyLocusId",
       "nullable": false,
-      "type": "long"
+      "type": "string"
     },
     {
       "metadata": {},
       "name": "rightStudyLocusId",
       "nullable": false,
-      "type": "long"
+      "type": "string"
     },
     {
       "metadata": {},
```
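Each of the five schema diffs above makes the same change: studyLocusId (and its leftStudyLocusId/rightStudyLocusId variants) moves from long to string. A minimal sketch of what this means for downstream code, assuming the schema is loaded with gentropy's parse_spark_schema helper (updated in the src/gentropy/common/schemas.py diff below):

```python
from pyspark.sql.types import StringType

from gentropy.common.schemas import parse_spark_schema

# Load the packaged study_locus schema (reads the JSON asset shown above):
study_locus_schema = parse_spark_schema("study_locus.json")

# After this change the identifier is a string, not a 64-bit long:
assert isinstance(study_locus_schema["studyLocusId"].dataType, StringType)
```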
224 changes: 185 additions & 39 deletions src/gentropy/common/schemas.py
```diff
@@ -1,66 +1,212 @@
 """Methods for handling schemas."""
 
 from __future__ import annotations
 
 import importlib.resources as pkg_resources
 import json
-from collections import namedtuple
-from typing import Any
+from collections import defaultdict
 
-import pyspark.sql.types as t
+from pyspark.sql.types import ArrayType, StructType
 
 from gentropy.assets import schemas
 
 
-def parse_spark_schema(schema_json: str) -> t.StructType:
+class SchemaValidationError(Exception):
+    """This exception is raised when schema validation fails."""
+
+    def __init__(
+        self: SchemaValidationError, message: str, errors: defaultdict[str, list[str]]
+    ) -> None:
+        """Initialize the SchemaValidationError.
+
+        Args:
+            message (str): The message to be displayed.
+            errors (defaultdict[str, list[str]]): The collection of observed discrepancies.
+        """
+        super().__init__(message)
+        self.message = message  # Explicitly set the message attribute
+        self.errors = errors
+
+    def __str__(self: SchemaValidationError) -> str:
+        """Return a string representation of the exception.
+
+        Returns:
+            str: The string representation of the exception.
+        """
+        stringified_errors = "\n  ".join(
+            [f'{k}: {",".join(v)}' for k, v in self.errors.items()]
+        )
+        return f"{self.message}\nErrors:\n  {stringified_errors}"
+
+
+def parse_spark_schema(schema_json: str) -> StructType:
     """Parse Spark schema from JSON.
 
     Args:
         schema_json (str): JSON filename containing spark schema in the schemas package
 
     Returns:
-        t.StructType: Spark schema
+        StructType: Spark schema
     """
     core_schema = json.loads(
         pkg_resources.read_text(schemas, schema_json, encoding="utf-8")
     )
-    return t.StructType.fromJson(core_schema)
+    return StructType.fromJson(core_schema)
 
 
-def flatten_schema(schema: t.StructType, prefix: str = "") -> list[Any]:
-    """It takes a Spark schema and returns a list of all fields in the schema once flattened.
+def compare_array_schemas(
+    observed_schema: ArrayType,
+    expected_schema: ArrayType,
+    parent_field_name: str | None = None,
+    schema_issues: defaultdict[str, list[str]] | None = None,
+) -> defaultdict[str, list[str]]:
+    """Compare two array schemas.
+
+    The comparison is done recursively, so nested structs are also compared.
 
     Args:
-        schema (t.StructType): The schema of the dataframe
-        prefix (str): The prefix to prepend to the field names. Defaults to "".
+        observed_schema (ArrayType): The observed schema.
+        expected_schema (ArrayType): The expected schema.
+        parent_field_name (str | None): The parent field name. Defaults to None.
+        schema_issues (defaultdict[str, list[str]] | None): The schema issues. Defaults to None.
 
     Returns:
-        list[Any]: A list of all the columns in the dataframe.
-
-    Examples:
-        >>> from pyspark.sql.types import ArrayType, StringType, StructField, StructType
-        >>> schema = StructType(
-        ...     [
-        ...         StructField("studyLocusId", StringType(), False),
-        ...         StructField("locus", ArrayType(StructType([StructField("variantId", StringType(), False)])), False)
-        ...     ]
-        ... )
-        >>> df = spark.createDataFrame([("A", [{"variantId": "varA"}]), ("B", [{"variantId": "varB"}])], schema)
-        >>> flatten_schema(df.schema)
-        [Field(name='studyLocusId', dataType=StringType()), Field(name='locus', dataType=ArrayType(StructType([]), True)), Field(name='locus.variantId', dataType=StringType())]
+        defaultdict[str, list[str]]: The schema issues.
     """
-    Field = namedtuple("Field", ["name", "dataType"])
-    fields = []
-    for field in schema.fields:
-        name = f"{prefix}.{field.name}" if prefix else field.name
-        dtype = field.dataType
-        if isinstance(dtype, t.StructType):
-            fields.append(Field(name, t.ArrayType(t.StructType())))
-            fields += flatten_schema(dtype, prefix=name)
-        elif isinstance(dtype, t.ArrayType) and isinstance(
-            dtype.elementType, t.StructType
-        ):
-            fields.append(Field(name, t.ArrayType(t.StructType())))
-            fields += flatten_schema(dtype.elementType, prefix=name)
-        else:
-            fields.append(Field(name, dtype))
-    return fields
+    # Create default values if not provided:
+    if schema_issues is None:
+        schema_issues = defaultdict(list)
+
+    if parent_field_name is None:
+        parent_field_name = ""
+
+    observed_type = observed_schema.elementType.typeName()
+    expected_type = expected_schema.elementType.typeName()
+
+    # If element types are not matching, no further tests are needed:
+    if observed_type != expected_type:
+        schema_issues["columns_with_non_matching_type"].append(
+            f'For column "{parent_field_name}[]" found {observed_type} instead of {expected_type}'
+        )
+
+    # If element type is a struct, resolve nesting:
+    elif (observed_type == "struct") and (expected_type == "struct"):
+        schema_issues = compare_struct_schemas(
+            observed_schema.elementType,
+            expected_schema.elementType,
+            f"{parent_field_name}[].",
+            schema_issues,
+        )
+
+    # If element type is an array, resolve nesting:
+    elif (observed_type == "array") and (expected_type == "array"):
+        schema_issues = compare_array_schemas(
+            observed_schema.elementType,
+            expected_schema.elementType,
+            parent_field_name,
+            schema_issues,
+        )
+
+    return schema_issues
+
+
+def compare_struct_schemas(
+    observed_schema: StructType,
+    expected_schema: StructType,
+    parent_field_name: str | None = None,
+    schema_issues: defaultdict[str, list[str]] | None = None,
+) -> defaultdict[str, list[str]]:
+    """Compare two struct schemas.
+
+    The comparison is done recursively, so nested structs are also compared.
+
+    Checking logic:
+        1. Check for duplicated columns in the observed schema.
+        2. Check for missing mandatory columns in the observed schema.
+        3. Once all mandatory columns are known to be present, iterate over the observed schema and compare types.
+        4. Flag unexpected columns in the observed schema.
+        5. Flag columns with non-matching types.
+        6. If a column is a struct -> call compare_struct_schemas.
+        7. If a column is an array -> call compare_array_schemas.
+        8. Return a dictionary with the issues.
+
+    Args:
+        observed_schema (StructType): The observed schema.
+        expected_schema (StructType): The expected schema.
+        parent_field_name (str | None): The parent field name. Defaults to None.
+        schema_issues (defaultdict[str, list[str]] | None): The schema issues. Defaults to None.
+
+    Returns:
+        defaultdict[str, list[str]]: The schema issues.
+    """
+    # Create default values if not provided:
+    if schema_issues is None:
+        schema_issues = defaultdict(list)
+
+    if parent_field_name is None:
+        parent_field_name = ""
+
+    # Flagging duplicated columns if present:
+    if duplicated_columns := list(
+        {
+            f"{parent_field_name}{field.name}"
+            for field in observed_schema
+            if list(observed_schema).count(field) > 1
+        }
+    ):
+        schema_issues["duplicated_columns"] += duplicated_columns
+
+    # Testing mandatory fields:
+    required_fields = [x.name for x in expected_schema if not x.nullable]
+    if missing_required_fields := [
+        f"{parent_field_name}{req}"
+        for req in required_fields
+        if not any(field.name == req for field in observed_schema)
+    ]:
+        schema_issues["missing_mandatory_columns"] += missing_required_fields
+
+    # Converting schemas to dictionaries for easier comparison:
+    observed_schema_dict = {field.name: field for field in observed_schema}
+    expected_schema_dict = {field.name: field for field in expected_schema}
+
+    # Testing optional fields and types:
+    for field_name, field in observed_schema_dict.items():
+        # Testing the observed field name; if the name is not matched, no further tests are needed:
+        if field_name not in expected_schema_dict:
+            schema_issues["unexpected_columns"].append(
+                f"{parent_field_name}{field_name}"
+            )
+            continue
+
+        # Once the field is confirmed to be in both schemas, extracting the type information:
+        observed_type = field.dataType
+        observed_type_name = field.dataType.typeName()
+
+        expected_type = expected_schema_dict[field_name].dataType
+        expected_type_name = expected_schema_dict[field_name].dataType.typeName()
+
+        # Flagging non-matching types, then jumping to the next field:
+        if observed_type_name != expected_type_name:
+            schema_issues["columns_with_non_matching_type"].append(
+                f'For column "{parent_field_name}{field_name}" found {observed_type_name} instead of {expected_type_name}'
+            )
+            continue
+
+        # If the column is a struct, resolve nesting:
+        if observed_type_name == "struct":
+            schema_issues = compare_struct_schemas(
+                observed_type,
+                expected_type,
+                f"{parent_field_name}{field_name}.",
+                schema_issues,
+            )
+        # If the column is an array, resolve nesting:
+        elif observed_type_name == "array":
+            schema_issues = compare_array_schemas(
+                observed_type,
+                expected_type,
+                f"{parent_field_name}{field_name}[]",
+                schema_issues,
+            )
+
+    return schema_issues
```
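The rewritten module replaces the old flatten_schema-based comparison with recursive compare_struct_schemas / compare_array_schemas helpers that collect discrepancies into a defaultdict. A usage sketch (the schemas below are illustrative, and the raise mirrors how a caller might surface the collected issues; the actual call site is not part of this diff):

```python
from collections import defaultdict

from pyspark.sql.types import (
    ArrayType,
    LongType,
    StringType,
    StructField,
    StructType,
)

from gentropy.common.schemas import SchemaValidationError, compare_struct_schemas

expected = StructType(
    [
        StructField("studyLocusId", StringType(), False),
        StructField(
            "locus",
            ArrayType(StructType([StructField("variantId", StringType(), False)])),
            True,
        ),
    ]
)
# Observed schema still carrying the pre-migration long identifiers:
observed = StructType(
    [
        StructField("studyLocusId", LongType(), False),
        StructField(
            "locus",
            ArrayType(StructType([StructField("variantId", LongType(), False)])),
            True,
        ),
    ]
)

issues: defaultdict[str, list[str]] = compare_struct_schemas(observed, expected)
# issues["columns_with_non_matching_type"] now flags both the top-level field
# ('For column "studyLocusId" found long instead of string') and the field
# nested under locus[], via the recursion through compare_array_schemas.
if issues:
    raise SchemaValidationError("Schema validation failed", issues)
```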
7 changes: 7 additions & 0 deletions src/gentropy/config.py
```diff
@@ -121,10 +121,16 @@ class GWASCatalogSumstatsPreprocessConfig(StepConfig):
 class EqtlCatalogueConfig(StepConfig):
     """eQTL Catalogue step configuration."""
 
+    session: Any = field(
+        default_factory=lambda: {
+            "start_hail": True,
+        }
+    )
     eqtl_catalogue_paths_imported: str = MISSING
     eqtl_catalogue_study_index_out: str = MISSING
     eqtl_catalogue_credible_sets_out: str = MISSING
     mqtl_quantification_methods_blacklist: list[str] = field(default_factory=lambda: [])
+    eqtl_lead_pvalue_threshold: float = 1e-3
     _target_: str = "gentropy.eqtl_catalogue.EqtlCatalogueStep"
 
 
@@ -168,6 +174,7 @@ class FinngenFinemappingConfig(StepConfig):
     _target_: str = (
         "gentropy.finngen_finemapping_ingestion.FinnGenFinemappingIngestionStep"
     )
+    finngen_finemapping_lead_pvalue_threshold: float = 1e-5
 
 
 @dataclass
```
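The config additions follow the Hydra-style dataclass pattern already used in gentropy.config (_target_, MISSING defaults). A minimal sketch of the new defaults, assuming the dataclass can be constructed directly rather than through Hydra; the output paths are hypothetical placeholders:

```python
from gentropy.config import EqtlCatalogueConfig

cfg = EqtlCatalogueConfig(
    eqtl_catalogue_paths_imported="gs://my-bucket/eqtl/imported",      # hypothetical
    eqtl_catalogue_study_index_out="gs://my-bucket/eqtl/study_index",  # hypothetical
    eqtl_catalogue_credible_sets_out="gs://my-bucket/eqtl/credible",   # hypothetical
)

# Defaults introduced by this diff:
assert cfg.session == {"start_hail": True}
assert cfg.eqtl_lead_pvalue_threshold == 1e-3
```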
(Diff truncated: the remaining 28 changed files are not shown.)
