diff --git a/adtl/__init__.py b/adtl/__init__.py index 2959235..66c67b7 100644 --- a/adtl/__init__.py +++ b/adtl/__init__.py @@ -104,19 +104,25 @@ def get_value_unhashed(row: StrDict, rule: Rule, ctx: Context = None) -> Any: value = getattr(tf, transformation)(value, *params) except AttributeError: raise AttributeError( - f"Error using a data transformation: Function {transformation} has not been defined." + f"Error using a data transformation: Function {transformation} " + "has not been defined." ) else: try: value = getattr(tf, transformation)(value) except AttributeError: raise AttributeError( - f"Error using a data transformation: Function {transformation} has not been defined." + f"Error using a data transformation: Function {transformation} " + "has not been defined." ) return value if value == "": return None if "values" in rule: + if rule.get("caseInsensitive") and isinstance(value, str): + value = value.lower() + rule["values"] = {k.lower(): v for k, v in rule["values"].items()} + if rule.get("ignoreMissingKey"): value = rule["values"].get(value, value) else: @@ -127,10 +133,10 @@ def get_value_unhashed(row: StrDict, rule: Rule, ctx: Context = None) -> Any: assert "source_date" not in rule and "date" not in rule source_unit = get_value(row, rule["source_unit"]) unit = rule["unit"] - if type(source_unit) != str: + if not isinstance(source_unit, str): logging.debug( - f"Error converting source_unit {source_unit} to {unit!r} with rule: {rule}, " - "defaulting to assume source_unit is {unit}" + f"Error converting source_unit {source_unit} to {unit!r} with " + f"rule: {rule}, defaulting to assume source_unit is {unit}" ) return float(value) try: @@ -198,7 +204,8 @@ def parse_if( cast_value = type(value)(attr_value) except ValueError: logging.debug( - f"Error when casting value {attr_value!r} with rule: {rule}, defaulting to False" + f"Error when casting value {attr_value!r} with rule: {rule}, defaulting" + " to False" ) return False if cmp == ">": @@ -227,7 
+234,8 @@ def parse_if( cast_value = type(value)(attr_value) except ValueError: logging.debug( - f"Error when casting value {attr_value!r} with rule: {rule}, defaulting to False" + f"Error when casting value {attr_value!r} with rule: {rule}, defaulting" + " to False" ) return False return cast_value == value @@ -371,7 +379,8 @@ def replace_val( for_expr = match.pop("for") if not isinstance(for_expr, dict): raise ValueError( - f"for expression {for_expr!r} is not a dictionary of variables to list of values or a range" + f"for expression {for_expr!r} is not a dictionary of variables to list " + "of values or a range" ) # Expand ranges when available @@ -390,7 +399,8 @@ def replace_val( pass else: raise ValueError( - f"for expression {for_expr!r} can only have lists or ranges for variables" + f"for expression {for_expr!r} can only have lists or ranges for " + "variables" ) loop_vars = sorted(for_expr.keys()) loop_assignments = [ @@ -564,12 +574,14 @@ def __init__( res = requests.get(schema) if res.status_code != 200: logging.warning( - f"Could not fetch schema for table {table!r}, will not validate" + f"Could not fetch schema for table {table!r}, will not " + "validate" ) continue except ConnectionError: # pragma: no cover logging.warning( - f"Could not fetch schema for table {table!r}, will not validate" + f"Could not fetch schema for table {table!r}, will not " + "validate" ) continue self.schemas[table] = make_fields_optional( @@ -618,7 +630,8 @@ def validate_spec(self): ) if group_field is not None and aggregation != "lastNotNull": raise ValueError( - f"groupBy needs aggregation=lastNotNull to be set for table: {table}" + "groupBy needs aggregation=lastNotNull to be set for table: " + f"{table}" ) def _set_field_names(self): @@ -632,7 +645,8 @@ def _set_field_names(self): else: if table not in self.schemas: print( - f"Warning: no schema found for {table!r}, field names may be incomplete!" 
+ f"Warning: no schema found for {table!r}, field names may be " + "incomplete!" ) self.fieldnames[table] = list( self.tables[table].get("common", {}).keys() @@ -734,7 +748,8 @@ def update_table(self, table: str, row: StrDict): if combined_type in ["all", "any", "min", "max"]: values = [existing_value, value] - # normally calling eval() is a bad idea, but here values are restricted, so okay + # normally calling eval() is a bad idea, but here + # values are restricted, so okay self.data[table][group_key][attr] = eval(combined_type)( values ) @@ -812,7 +827,8 @@ def parse_rows(self, rows: Iterable[StrDict], skip_validation=False): """Transform rows from an iterable according to specification Args: - rows: Iterable of rows, specified as a dictionary of (field name, field value) pairs + rows: Iterable of rows, specified as a dictionary of + (field name, field value) pairs skip_validation: Whether to skip validation, default off Returns: @@ -879,7 +895,8 @@ def write_csv( Args: table: Table that should be written to CSV - output: (optional) Output file name. If not specified, defaults to parser name + table name + output: (optional) Output file name. If not specified, defaults to parser + name + table name with a csv suffix. 
""" @@ -960,8 +977,9 @@ def show_report(self): print("|---------------|-------|-------|----------------|") for table in self.report["total"]: print( - f"|{table:14s}\t|{self.report['total_valid'][table]}\t|{self.report['total'][table]}\t" - f"|{self.report['total_valid'][table]/self.report['total'][table]:%} |" + f"|{table:14s}\t|{self.report['total_valid'][table]}\t" + f"|{self.report['total'][table]}\t" + f"|{self.report['total_valid'][table]/self.report['total'][table]:%} |" # noqa:E501 ) print() for table in self.report["validation_errors"]: diff --git a/docs/_static/style.css b/docs/_static/style.css index 27cafa6..3bd4c00 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -111,4 +111,3 @@ div.bodywrapper h4 { padding-right: 0; } } - diff --git a/docs/conf.py b/docs/conf.py index 2cdcbce..57ba44e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -17,7 +17,7 @@ project = "adtl" copyright = "2023, Global.health" -release = "0.5.0" +release = "0.6.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration @@ -38,7 +38,7 @@ html_theme = "better" html_static_path = ["_static"] -html_theme_path=[better.better_theme_path] +html_theme_path = [better.better_theme_path] html_short_title = "Home" html_theme_options = { @@ -46,4 +46,4 @@ "sidebarwidth": "25rem", "cssfiles": ["_static/style.css"], "showheader": False, -} \ No newline at end of file +} diff --git a/docs/specification.md b/docs/specification.md index ae098a2..ec46008 100644 --- a/docs/specification.md +++ b/docs/specification.md @@ -287,6 +287,22 @@ values = { 1 = true, 2 = false } description = "Dementia" ``` +If the data for this field has a range of different capitalisations and you wish to +capture them all without specifying each variant, you can add `caseInsensitive = true` +to the rule: + +```toml +[table.sex_at_birth] +field = "sex" +values = { homme = "male", femme = "female" } 
+caseInsensitive = true +``` + +When the parser encounters e.g. `Homme` or `FEMME` in the data it will still match to +`male` and `female` respectively. The parser will still ignore different spellings, e.g. +`Home` will return `null`. + + ### Combined type Refers to multiple fields in the source format. Requires diff --git a/pyproject.toml b/pyproject.toml index 2b91465..541ad07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ packages = ["adtl"] [project] name = "adtl" -version = "0.5.0" +version = "0.6.0" description = "Another data transformation language" authors = [ {name = "Abhishek Dasgupta", email = "abhishek.dasgupta@dtc.ox.ac.uk"}, diff --git a/tests/parsers/stop-overwriting.toml b/tests/parsers/stop-overwriting.toml index 7ebdc0b..65d7b4c 100644 --- a/tests/parsers/stop-overwriting.toml +++ b/tests/parsers/stop-overwriting.toml @@ -12,13 +12,13 @@ [visit.subject_id] field = "subjid" description = "Subject ID" - + [visit.earliest_admission] combinedType = "min" fields = [ { field = "first_admit" }, ] - + [visit.start_date] combinedType = "firstNonNull" fields = [ @@ -44,4 +44,3 @@ { field = "overall_antiviral_dc___2", values = { 1 = "Lopinavir" } }, { field = "overall_antiviral_dc___3", values = { 1 = "Interferon" } }, ] - diff --git a/tests/sources/stop-overwriting.csv b/tests/sources/stop-overwriting.csv index 2455376..a776030 100644 --- a/tests/sources/stop-overwriting.csv +++ b/tests/sources/stop-overwriting.csv @@ -8,4 +8,4 @@ subjid,redcap,first_admit,enrolment,icu_admission_date,daily_antiviral_type___1, 2,day1,,,2020-11-30,0,1,0,0,0,0 3,admit,,2020-02-20,,0,0,0,0,0,0 3,discharge,,,,0,0,0,0,1,1 -3,day1,,,,1,0,0,0,0,0 \ No newline at end of file +3,day1,,,,1,0,0,0,0,0 diff --git a/tests/test_parser.py b/tests/test_parser.py index 7ea75c8..ca48b3f 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -66,6 +66,12 @@ "ignoreMissingKey": True, } +RULE_CASEINSENSITIVE = { + "field": "diabetes_mhyn", + "values": {"Type 1": 
"E10", "TYPE 2": "E11"}, # ICD-10 codes + "caseInsensitive": True, +} + ROW_CONDITIONAL = {"outcome_date": "2022-01-01", "outcome_type": 4} RULE_CONDITIONAL_OK = {"field": "outcome_date", "if": {"outcome_type": 4}} RULE_CONDITIONAL_FAIL = {"field": "outcome_date", "if": {"outcome_type": {"<": 4}}} @@ -445,6 +451,8 @@ def _subdict(d: Dict, keys: Iterable[Any]) -> Dict[str, Any]: (({"first": "", "second": False}, RULE_COMBINED_FIRST_NON_NULL), False), (({"diabetes_mhyn": "type 1"}, RULE_IGNOREMISSINGKEY), "E10"), (({"diabetes_mhyn": "gestational"}, RULE_IGNOREMISSINGKEY), "gestational"), + (({"diabetes_mhyn": "type 2"}, RULE_CASEINSENSITIVE), "E11"), + (({"diabetes_mhyn": "TYPE 1"}, RULE_CASEINSENSITIVE), "E10"), ((ROW_CONDITIONAL, RULE_CONDITIONAL_OK), "2022-01-01"), ((ROW_CONDITIONAL, RULE_CONDITIONAL_FAIL), None), ((ROW_UNIT_MONTH, RULE_UNIT), 1.5),