diff --git a/liiatools/__main__.py b/liiatools/__main__.py
index c09ae46c..91f47439 100644
--- a/liiatools/__main__.py
+++ b/liiatools/__main__.py
@@ -2,7 +2,7 @@
 from liiatools.annex_a_pipeline.cli import annex_a
-# from liiatools.datasets.cin_census.cin_cli import cin_census
+from liiatools.cin_census_pipeline.cli import cin_census
 from liiatools.csww_pipeline.cli import csww
 from liiatools.ssda903_pipeline.cli import s903
 from liiatools.s251_pipeline.cli import s251
@@ -14,7 +14,7 @@ def cli():
 cli.add_command(annex_a)
-# cli.add_command(cin_census)
+cli.add_command(cin_census)
 cli.add_command(s903)
 cli.add_command(csww)
 cli.add_command(s251)
diff --git a/liiatools/cin_census_pipeline/_reports_assessment_factors.py b/liiatools/cin_census_pipeline/_reports_assessment_factors.py
index 80671af9..e32ed47f 100644
--- a/liiatools/cin_census_pipeline/_reports_assessment_factors.py
+++ b/liiatools/cin_census_pipeline/_reports_assessment_factors.py
@@ -2,10 +2,10 @@
 def expanded_assessment_factors(
-    data: pd.DataFrame, column_name="AssessmentFactor", prefix: str = ""
+    data: pd.DataFrame, column_name="Factors", prefix: str = ""
 ) -> pd.DataFrame:
     """
-    Expects to receive a dataframe with a column named 'AssessmentFactor' containing a comma-separated list of values.
+    Expects to receive a dataframe with a column named "Factors" containing a comma-separated list of values.
     Expands these values into a "one-hot" encoding of the values.
 
     Can optionally prefix the column names with a prefix string.
diff --git a/liiatools/cin_census_pipeline/_reports_s47_journeys.py b/liiatools/cin_census_pipeline/_reports_s47_journeys.py
new file mode 100644
index 00000000..e69de29b
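
The hunk above only renames the expanded column; the expansion itself is defined elsewhere in the module. A minimal sketch of the behaviour the docstring describes, using pandas' str.get_dummies as a stand-in (not necessarily the module's actual implementation):

    import pandas as pd

    df = pd.DataFrame({"Factors": ["1A,2B", "2B", None]})

    # One-hot encode the comma-separated values; add_prefix mirrors the
    # optional `prefix` argument of expanded_assessment_factors
    one_hot = df["Factors"].str.get_dummies(sep=",").add_prefix("factor_")
    df = pd.concat([df, one_hot], axis=1)
    print(df.columns.tolist())  # ['Factors', 'factor_1A', 'factor_2B']
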
diff --git a/liiatools/cin_census_pipeline/cli.py b/liiatools/cin_census_pipeline/cli.py
new file mode 100644
index 00000000..422ebeb4
--- /dev/null
+++ b/liiatools/cin_census_pipeline/cli.py
@@ -0,0 +1,57 @@
+import logging
+
+import click
+import click_log
+from fs import open_fs
+
+from liiatools.common.reference import authorities
+
+from .pipeline import process_session
+
+log = logging.getLogger()
+click_log.basic_config(log)
+
+
+@click.group()
+def cin_census():
+    """Functions for cleaning, minimising and aggregating CIN census files"""
+    pass
+
+
+@cin_census.command()
+@click.option(
+    "--la-code",
+    "-c",
+    required=True,
+    type=click.Choice(authorities.codes, case_sensitive=False),
+    help="Local authority code",
+)
+@click.option(
+    "--output",
+    "-o",
+    required=True,
+    type=click.Path(file_okay=False, writable=True),
+    help="Output folder",
+)
+@click.option(
+    "--input",
+    "-i",
+    type=click.Path(exists=True, file_okay=False, readable=True),
+)
+@click_log.simple_verbosity_option(log)
+def pipeline(input, la_code, output):
+    """
+    Runs the full pipeline on a file or folder
+    :param input: The path to the input folder
+    :param la_code: A three-letter string for the local authority depositing the file
+    :param output: The path to the output folder
+    :return: None
+    """
+
+    # Source FS is the filesystem containing the input files
+    source_fs = open_fs(input)
+
+    # Get the output filesystem
+    output_fs = open_fs(output)
+
+    process_session(source_fs, output_fs, la_code)
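
The new command group can be smoke-tested through Click's test runner, as the end-to-end test further down does; the input and output paths here are placeholders:

    from click.testing import CliRunner

    from liiatools.__main__ import cli

    runner = CliRunner()
    result = runner.invoke(
        cli,
        ["cin-census", "pipeline", "-c", "BAR", "-i", "./input", "-o", "./output"],
        catch_exceptions=False,
    )
    assert result.exit_code == 0
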
diff --git a/liiatools/cin_census_pipeline/pipeline.py b/liiatools/cin_census_pipeline/pipeline.py
index cccc055e..b61c4be6 100644
--- a/liiatools/cin_census_pipeline/pipeline.py
+++ b/liiatools/cin_census_pipeline/pipeline.py
@@ -1,18 +1,29 @@
-from fs import FS
+import logging
+
+from fs import open_fs
+from fs.base import FS
 
 from liiatools.common import pipeline as pl
 from liiatools.common.archive import DataframeArchive
 from liiatools.common.constants import ProcessNames, SessionNames
 from liiatools.common.data import (
-    DataContainer,
     ErrorContainer,
     FileLocator,
     PipelineConfig,
     ProcessResult,
-    TableConfig,
 )
 from liiatools.common.transform import degrade_data, enrich_data, prepare_export
+from liiatools.cin_census_pipeline.spec import (
+    load_pipeline_config,
+    load_schema,
+    load_schema_path,
+)
+
+from liiatools.cin_census_pipeline.stream_pipeline import task_cleanfile
+
+
+logger = logging.getLogger()
+
 
 def process_file(
     file_locator: FileLocator,
@@ -20,6 +31,14 @@ def process_file(
     pipeline_config: PipelineConfig,
     la_code: str,
 ) -> ProcessResult:
+    """
+    Clean, enrich and degrade data
+    :param file_locator: The pointer to a file in a virtual filesystem
+    :param session_folder: The path to the session folder
+    :param pipeline_config: The pipeline configuration
+    :param la_code: A three-letter string for the local authority depositing the file
+    :return: A class containing a DataContainer and ErrorContainer
+    """
     errors = ErrorContainer()
     year = pl.discover_year(file_locator)
     if year is None:
@@ -35,10 +54,14 @@ def process_file(
     # We save these files based on the session UUID - so UUID must exist
     uuid = file_locator.meta["uuid"]
 
+    # Load schema and set on processing metadata
+    schema = load_schema(year=year)
+    schema_path = load_schema_path(year=year)
     metadata = dict(year=year, schema=schema, la_code=la_code)
 
+    # Normalise the data and export to the session 'cleaned' folder
     try:
-        cleanfile_result = task_cleanfile(file_locator, schema)
+        cleanfile_result = task_cleanfile(file_locator, schema, schema_path)
     except Exception as e:
         logger.exception(f"Error cleaning file {file_locator.name}")
         errors.append(
@@ -50,8 +73,43 @@ def process_file(
         )
         return ProcessResult(data=None, errors=errors)
 
+    # Export the cleaned data to the session 'cleaned' folder
+    cleanfile_result.data.export(
+        session_folder, f"{SessionNames.CLEANED_FOLDER}/{uuid}_", "parquet"
+    )
+    errors.extend(cleanfile_result.errors)
+
+    # Enrich the data and export to the session 'enriched' folder
+    enrich_result = enrich_data(cleanfile_result.data, pipeline_config, metadata)
+    enrich_result.data.export(
+        session_folder, f"{SessionNames.ENRICHED_FOLDER}/{uuid}_", "parquet"
+    )
+    errors.extend(enrich_result.errors)
+
+    # Degrade the data and export to the session 'degraded' folder
+    degraded_result = degrade_data(enrich_result.data, pipeline_config, metadata)
+    degraded_result.data.export(
+        session_folder, f"{SessionNames.DEGRADED_FOLDER}/{uuid}_", "parquet"
+    )
+    errors.extend(degraded_result.errors)
+
+    errors.set_property("filename", file_locator.name)
+    errors.set_property("uuid", uuid)
+
+    return ProcessResult(data=degraded_result.data, errors=errors)
+
 
 def process_session(source_fs: FS, output_fs: FS, la_code: str):
+    """
+    Runs the full pipeline on a file or folder
+    :param source_fs: File system containing the input files
+    :param output_fs: File system for the output files
+    :param la_code: A three-letter string for the local authority depositing the file
+    :return: None
+    """
+    # Before we start - load configuration for this dataset
+    pipeline_config = load_pipeline_config()
+
     # Ensure all processing folders exist
     pl.create_process_folders(output_fs)
 
@@ -88,3 +146,10 @@ def process_session(source_fs: FS, output_fs: FS, la_code: str):
     current_data.export(
         output_fs.opendir(ProcessNames.CURRENT_FOLDER), "cin_cencus_current_", "csv"
     )
+
+    # Create the different reports
+    export_folder = output_fs.opendir(ProcessNames.EXPORT_FOLDER)
+    for report in ["PAN"]:
+        report_data = prepare_export(current_data, pipeline_config, profile=report)
+        report_folder = export_folder.makedirs(report, recreate=True)
+        report_data.data.export(report_folder, "cin_census_", "csv")
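
A hedged sketch of driving process_session directly from code, which is what the CLI wrapper above does; the paths are placeholders and any pyfilesystem2 location (osfs://, mem://, ...) should work:

    from fs import open_fs

    from liiatools.cin_census_pipeline.pipeline import process_session

    source_fs = open_fs("./pipeline/input")   # folder of deposited CIN XML files
    output_fs = open_fs("./pipeline/output")  # sessions/archive/current/export are created here
    process_session(source_fs, output_fs, la_code="BAR")
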
diff --git a/liiatools/cin_census_pipeline/spec/CIN_schema_2017.xsd b/liiatools/cin_census_pipeline/spec/CIN_schema_2017.xsd
index 410576f3..d558bc1e 100644
--- a/liiatools/cin_census_pipeline/spec/CIN_schema_2017.xsd
+++ b/liiatools/cin_census_pipeline/spec/CIN_schema_2017.xsd
[hunks illegible: the XSD markup was stripped in extraction; the surviving fragments show yes/no values (False/True) and the category-of-abuse enumeration (Neglect, Physical abuse, ...) being reworked in hunks around lines 76, 143-170 and 280]
diff --git a/liiatools/cin_census_pipeline/spec/CIN_schema_2018.xsd b/liiatools/cin_census_pipeline/spec/CIN_schema_2018.xsd
index 410576f3..fa859811 100644
--- a/liiatools/cin_census_pipeline/spec/CIN_schema_2018.xsd
+++ b/liiatools/cin_census_pipeline/spec/CIN_schema_2018.xsd
[hunks illegible: XSD markup stripped in extraction; same pattern of changes as 2017]
diff --git a/liiatools/cin_census_pipeline/spec/CIN_schema_2019.xsd b/liiatools/cin_census_pipeline/spec/CIN_schema_2019.xsd
index 410576f3..fa859811 100644
--- a/liiatools/cin_census_pipeline/spec/CIN_schema_2019.xsd
+++ b/liiatools/cin_census_pipeline/spec/CIN_schema_2019.xsd
[hunks illegible: XSD markup stripped in extraction; same pattern of changes as 2018]
diff --git a/liiatools/cin_census_pipeline/spec/CIN_schema_2020.xsd b/liiatools/cin_census_pipeline/spec/CIN_schema_2020.xsd
index 410576f3..fa859811 100644
--- a/liiatools/cin_census_pipeline/spec/CIN_schema_2020.xsd
+++ b/liiatools/cin_census_pipeline/spec/CIN_schema_2020.xsd
[hunks illegible: XSD markup stripped in extraction; same pattern of changes as 2018]
diff --git a/liiatools/cin_census_pipeline/spec/CIN_schema_2021.xsd b/liiatools/cin_census_pipeline/spec/CIN_schema_2021.xsd
index b6175dc8..05716458 100644
--- a/liiatools/cin_census_pipeline/spec/CIN_schema_2021.xsd
+++ b/liiatools/cin_census_pipeline/spec/CIN_schema_2021.xsd
[hunks illegible: XSD markup stripped in extraction; same pattern of changes as 2018]
diff --git a/liiatools/cin_census_pipeline/spec/CIN_schema_2022.xsd b/liiatools/cin_census_pipeline/spec/CIN_schema_2022.xsd
index 08833eb9..772a8587 100644
--- a/liiatools/cin_census_pipeline/spec/CIN_schema_2022.xsd
+++ b/liiatools/cin_census_pipeline/spec/CIN_schema_2022.xsd
[hunks illegible: XSD markup stripped in extraction; hunks around lines 76, 151-170 and 288]
diff --git a/liiatools/cin_census_pipeline/spec/CIN_schema_2023.xsd b/liiatools/cin_census_pipeline/spec/CIN_schema_2023.xsd
index 23a594d0..c9f973f5 100644
--- a/liiatools/cin_census_pipeline/spec/CIN_schema_2023.xsd
+++ b/liiatools/cin_census_pipeline/spec/CIN_schema_2023.xsd
[hunks illegible: XSD markup stripped in extraction; hunks around lines 76, 151-170 and 288]
diff --git a/liiatools/cin_census_pipeline/spec/CIN_schema_2024.xsd b/liiatools/cin_census_pipeline/spec/CIN_schema_2024.xsd
index 23a594d0..c9f973f5 100644
--- a/liiatools/cin_census_pipeline/spec/CIN_schema_2024.xsd
+++ b/liiatools/cin_census_pipeline/spec/CIN_schema_2024.xsd
[hunks illegible: XSD markup stripped in extraction; hunks around lines 76, 151-170 and 288]
diff --git a/liiatools/cin_census_pipeline/spec/__init__.py b/liiatools/cin_census_pipeline/spec/__init__.py
index f11bced6..61375aa2 100644
--- a/liiatools/cin_census_pipeline/spec/__init__.py
+++ b/liiatools/cin_census_pipeline/spec/__init__.py
@@ -3,9 +3,23 @@
 
 import xmlschema
+from pydantic_yaml import parse_yaml_file_as
+
+from liiatools.common.data import PipelineConfig
 
 SCHEMA_DIR = Path(__file__).parent
 
 
+@lru_cache
+def load_pipeline_config():
+    with open(SCHEMA_DIR / "pipeline.yml", "rt") as FILE:
+        return parse_yaml_file_as(PipelineConfig, FILE)
+
+
 @lru_cache
 def load_schema(year: int) -> xmlschema.XMLSchema:
     return xmlschema.XMLSchema(SCHEMA_DIR / f"CIN_schema_{year:04d}.xsd")
+
+
+def load_schema_path(year: int) -> Path:
+    return Path(SCHEMA_DIR, f"CIN_schema_{year:04d}.xsd")
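
The three loaders are intended to be used together; a short sketch (year 2022 is just an example):

    from liiatools.cin_census_pipeline.spec import (
        load_pipeline_config,
        load_schema,
        load_schema_path,
    )

    config = load_pipeline_config()       # PipelineConfig parsed from pipeline.yml, cached
    schema = load_schema(2022)            # xmlschema.XMLSchema object, cached per year
    schema_path = load_schema_path(2022)  # plain Path to the same .xsd, for the ElementTree-based helpers
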
diff --git a/liiatools/cin_census_pipeline/spec/pipeline.yml b/liiatools/cin_census_pipeline/spec/pipeline.yml
new file mode 100644
index 00000000..c56af0e3
--- /dev/null
+++ b/liiatools/cin_census_pipeline/spec/pipeline.yml
@@ -0,0 +1,90 @@
+table_list:
+- id: CIN
+  retain:
+  - PAN
+  columns:
+  - id: LAchildID
+    type: string
+    unique_key: true
+  - id: Date
+    type: date
+    unique_key: true
+  - id: Type
+    type: string
+    unique_key: true
+  - id: CINreferralDate
+    type: date
+  - id: ReferralSource
+    type: category
+    unique_key: true
+  - id: PrimaryNeedCode
+    type: category
+  - id: CINclosureDate
+    type: date
+  - id: ReasonForClosure
+    type: category
+  - id: DateOfInitialCPC
+    type: date
+  - id: ReferralNFA
+    type: category
+  - id: CINPlanStartDate
+    type: date
+  - id: CINPlanEndDate
+    type: date
+  - id: S47ActualStartDate
+    type: date
+  - id: InitialCPCtarget
+    type: date
+    sort: 2
+  - id: ICPCnotRequired
+    type: category
+  - id: AssessmentActualStartDate
+    type: date
+  - id: AssessmentInternalReviewDate
+    type: date
+    sort: 1
+  - id: AssessmentAuthorisationDate
+    type: date
+  - id: Factors
+    type: category
+  - id: CPPstartDate
+    type: date
+  - id: CPPendDate
+    type: date
+  - id: InitialCategoryOfAbuse
+    type: category
+  - id: LatestCategoryOfAbuse
+    type: category
+  - id: NumberOfPreviousCPP
+    type: numeric
+  - id: UPN
+    type: string
+  - id: FormerUPN
+    type: string
+  - id: UPNunknown
+    type: category
+  - id: PersonBirthDate
+    type: date
+    degrade: first_of_month
+  - id: ExpectedPersonBirthDate
+    type: date
+    degrade: first_of_month
+  - id: GenderCurrent
+    type: category
+  - id: PersonDeathDate
+    type: date
+    degrade: first_of_month
+  - id: PersonSchoolYear
+    type: numeric
+    enrich: school_year
+  - id: Ethnicity
+    type: category
+  - id: Disabilities
+    type: category
+  - id: LA
+    type: string
+    enrich: la_name
+  - id: Year
+    type: numeric
+    enrich: year
+    sort: 0
\ No newline at end of file
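
For reference: `type` drives cell conformance, `unique_key` marks deduplication keys, `sort` orders rows that share a key, and `enrich`/`degrade` select transform functions by name. A minimal sketch of what a `degrade: first_of_month` rule implies for a date column; the real degrade implementation lives in liiatools.common and may differ in detail:

    import pandas as pd

    births = pd.Series(pd.to_datetime(["2015-09-17", "2019-04-02"]))
    # Truncate to the first of the month, dropping day-level identifiability
    degraded = births.dt.to_period("M").dt.to_timestamp()
    print(degraded.dt.strftime("%Y-%m-%d").tolist())  # ['2015-09-01', '2019-04-01']
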
diff --git a/liiatools/cin_census_pipeline/stream_filters.py b/liiatools/cin_census_pipeline/stream_filters.py
index 997adb35..efe648a9 100644
--- a/liiatools/cin_census_pipeline/stream_filters.py
+++ b/liiatools/cin_census_pipeline/stream_filters.py
@@ -1,205 +1,59 @@
 import logging
-from collections import defaultdict, deque
-from typing import List
+from pathlib import Path
 
-import xmlschema
 from sfdata_stream_parser import events
 from sfdata_stream_parser.checks import type_check
-from sfdata_stream_parser.collectors import block_check, collector
 from sfdata_stream_parser.filters.generic import pass_event, streamfilter
 
-logger = logging.getLogger(__name__)
-
-
-@streamfilter(check=type_check(events.TextNode), fail_function=pass_event)
-def strip_text(event):
-    """
-    Strips surrounding whitespaces from :class:`sfdata_stream_parser.events.TextNode`. If the event does
-    not have a text property then this filter fails silently.
-
-    If there is no content at all, then the node is not returned.
-    """
-    if not hasattr(event, "text"):
-        return event
-
-    if event.text is None:
-        return None
-
-    text = event.text.strip()
-    if len(text) > 0:
-        return event.from_event(event, text=text)
-    else:
-        return None
-
-
-@streamfilter(default_args=lambda: {"context": []})
-def add_context(event, context: List[str]):
-    """
-    Adds 'context' to XML structures. For each :class:`sfdata_stream_parser.events.StartElement` the tag name is
-    added to a 'context' tuple, and for each :class:`sfdata_stream_parser.events.EndElement` the context is popped.
-
-    For all other events, the context tuple is set as-is.
-
-    Provides: context
-    """
-    if isinstance(event, events.StartElement):
-        context.append(event.tag)
-        local_context = tuple(context)
-    elif isinstance(event, events.EndElement):
-        local_context = tuple(context)
-        context.pop()
-    else:
-        local_context = tuple(context)
-
-    return event.from_event(event, context=local_context)
+from liiatools.common.spec.__data_schema import Column, Numeric
+from liiatools.common.stream_filters import _create_category_spec, _create_regex_spec, _create_numeric_spec
 
-
-@streamfilter()
-def add_schema(event, schema: xmlschema.XMLSchema):
-    """
-    Requires each event to have event.context as set by :func:`add_context`
-
-    Based on the context (a tuple of element tags) it will set path which is the
-    derived path (based on the context tags) joined by '/' and schema holding the
-    corresponding schema element, if found.
-
-    Provides: path, schema
-    """
-    assert (
-        event.context
-    ), "This filter required event.context to be set - see add_context"
-    path = "/".join(event.context)
-    tag = event.context[-1]
-    el = schema.get_element(tag, path)
-    return event.from_event(event, path=path, schema=el)
+logger = logging.getLogger(__name__)
 
 
-def inherit_LAchildID(stream):
+@streamfilter(
+    check=type_check(events.TextNode),
+    fail_function=pass_event,
+    error_function=pass_event,
+)
+def add_column_spec(event, schema_path: Path):
     """
-    Apply the LAchildID to all elements within <Child></Child>
+    Add a Column class containing schema attributes to an event object based on its type and occurrence
 
-    :param stream: A filtered list of event objects
-    :return: An updated list of event objects
+    :param event: An event object with a schema attribute
+    :param schema_path: The path to the schema file
+    :return: A new event object with a column_spec attribute, or the original event object if no schema is found
     """
-    child_id = child_events = None
-    for event in stream:
-        if isinstance(event, events.StartElement) and event.tag == "Child":
-            child_events = [event]
-        elif isinstance(event, events.EndElement) and event.tag == "Child":
-            child_events.append(event)
-            yield from [e.from_event(e, LAchildID=child_id) for e in child_events]
-            child_id = child_events = None
-        elif isinstance(event, events.TextNode):
-            try:
-                child_events.append(event)
-                if event.schema.name == "LAchildID":
-                    child_id = event.text
-            except (
-                AttributeError
-            ):  # Raised in case there is no LAchildID or event is None
-                pass
-        elif child_events:
-            child_events.append(event)
-        else:
-            yield event
-
-
-def validate_elements(stream):
-    error_dict = {}
-    for event in stream:
-        # Only validate root element
-        if isinstance(event, events.StartElement) and event.node.getparent() is None:
-            error_list = list(event.schema.iter_errors(event.node))
-            error_dict = defaultdict(list)
-            for e in error_list:
-                error_dict[e.elem].append(e)
-        if isinstance(event, events.StartElement):
-            if event.node in error_dict:
-                event = event.from_event(
-                    event, valid=False, validation_errors=error_dict[event.node]
-                )
-        yield event
+    column_spec = Column()
 
+    if event.schema.occurs[0] == 1:
+        column_spec.canbeblank = False
 
-@streamfilter(check=lambda x: True)
-def counter(event, counter_check, value_error, structural_error, blank_error):
-    """
-    Count the invalid simple nodes storing their name and LAchildID data
-
-    :param event: A filtered list of event objects
-    :param counter_check: A function to identify which events to check
-    :param value_error: An empty list to store the invalid element information
-    :param structural_error: An empty list to store the invalid structure information
-    :param blank_error: An empty list to store the blank element information
-    :return: The same filtered list of event objects
-    """
-    if counter_check(event) and len(event.node) == 0:
+    config_type = event.schema.type.name
+    if config_type is not None:
+        if config_type[-4:] == "type":
+            column_spec.category = _create_category_spec(config_type, schema_path)
+        if config_type in ["positiveintegertype"]:
+            column_spec.numeric = _create_numeric_spec(config_type, schema_path)
+        if config_type in ["upntype"]:
+            column_spec.string = "regex"
+            column_spec.cell_regex = _create_regex_spec(config_type, schema_path)
         if (
-            getattr(event, "LAchildID", None) is not None
-        ):  # In case there are errors in the <Header> node as none
-            # of these will have an LAchildID assigned
-            if hasattr(event, "validation_message"):
-                blank_error.append(
-                    f"LAchildID: {event.LAchildID}, Node: {event.schema.name}"
-                )
-            elif hasattr(event.schema, "name"):
-                value_error.append(
-                    f"LAchildID: {event.LAchildID}, Node: {event.schema.name}"
-                )
-            else:
-                structural_error.append(
-                    f"LAchildID: {event.LAchildID}, Node: {event.tag}"
-                )
-        else:
-            if hasattr(event, "validation_message"):
-                blank_error.append(f"LAchildID: blank, Node: {event.schema.name}")
-            elif hasattr(event.schema, "name"):
-                value_error.append(f"LAchildID: blank, Node: {event.schema.name}")
-            else:
-                structural_error.append(f"LAchildID: blank, Node: {event.tag}")
-    return event
-
-
-@streamfilter(check=type_check(events.TextNode), fail_function=pass_event)
-def convert_true_false(event):
-    """
-    Search for any events that have the schema type="yesnotype" and convert any values of false to 0 and true to 1
-
-    :param event: A filtered list of event objects
-    :return: An updated list of event objects
-    """
-    if hasattr(event, "schema"):
-        if event.schema.type.name == "yesnotype":
-            if event.text.lower() == "false":
-                event = event.from_event(event, text="0")
-            elif event.text.lower() == "true":
-                event = event.from_event(event, text="1")
-    return event
-
-
-@collector(check=block_check(events.StartElement), receive_stream=True)
-def remove_invalid(stream):
-    """
-    Filters out events with the given tag name if they are not valid
-
-    :param stream: A filtered list of event objects
-    :return: An updated list of event objects
-    """
-    stream = deque(stream)
-    first = stream.popleft()
-    last = stream.pop()
-
-    is_valid = getattr(first, "valid", True)
-    is_content = first.schema.type.is_simple()
-
-    if is_content and not is_valid:
-        messages = ", ".join([e.reason for e in first.validation_errors])
-        logger.info("Removing invalid content in tag %s: %s", first.tag, messages)
-        return
-
-    yield first
-
-    if len(stream) > 0:
-        yield from remove_invalid(stream)
+            config_type == "{http://www.w3.org/2001/XMLSchema}date"
+        ):
+            column_spec.date = "%Y-%m-%d"
+        if (
+            config_type == "{http://www.w3.org/2001/XMLSchema}dateTime"
+        ):
+            column_spec.date = "%Y-%m-%dT%H:%M:%S"
+        if config_type in [
+            "{http://www.w3.org/2001/XMLSchema}integer",
+            "{http://www.w3.org/2001/XMLSchema}gYear"
+        ]:
+            column_spec.numeric = Numeric(type="integer")
+        if config_type == "{http://www.w3.org/2001/XMLSchema}string":
+            column_spec.string = "alphanumeric"
+    else:
+        column_spec.string = "alphanumeric"
 
-    yield last
+    return event.from_event(event, column_spec=column_spec)
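
The type-name dispatch above is dense; a condensed restatement of the built-in XSD branches, as a standalone sketch (the *type category/numeric/regex branches need the schema file and are omitted):

    from liiatools.common.spec.__data_schema import Column, Numeric

    def sketch_column_for(config_type: str) -> Column:
        # Mirrors the dispatch in add_column_spec for built-in XSD types only
        column_spec = Column()
        if config_type == "{http://www.w3.org/2001/XMLSchema}date":
            column_spec.date = "%Y-%m-%d"
        elif config_type == "{http://www.w3.org/2001/XMLSchema}dateTime":
            column_spec.date = "%Y-%m-%dT%H:%M:%S"
        elif config_type in (
            "{http://www.w3.org/2001/XMLSchema}integer",
            "{http://www.w3.org/2001/XMLSchema}gYear",
        ):
            column_spec.numeric = Numeric(type="integer")
        else:
            column_spec.string = "alphanumeric"
        return column_spec

    print(sketch_column_for("{http://www.w3.org/2001/XMLSchema}date").date)  # %Y-%m-%d
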
diff --git a/liiatools/cin_census_pipeline/stream_parse.py b/liiatools/cin_census_pipeline/stream_parse.py
deleted file mode 100644
index 1408b5a6..00000000
--- a/liiatools/cin_census_pipeline/stream_parse.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from sfdata_stream_parser.events import (
-    StartElement,
-    EndElement,
-    TextNode,
-    CommentNode,
-    ProcessingInstructionNode,
-)
-
-try:
-    from lxml import etree
-except ImportError:
-    pass
-
-
-def dom_parse(source, **kwargs):
-    """
-    Equivalent of the xml parse included in the sfdata_stream_parser package, but uses the ET DOM
-    and allows direct DOM manipulation.
-    """
-    parser = etree.iterparse(source, events=("start", "end", "comment", "pi"), **kwargs)
-    for action, elem in parser:
-        if action == "start":
-            yield StartElement(tag=elem.tag, attrib=elem.attrib, node=elem)
-            if elem.text:
-                yield TextNode(text=elem.text)
-        elif action == "end":
-            yield EndElement(tag=elem.tag, node=elem)
-            if elem.tail:
-                yield TextNode(text=elem.tail)
-        elif action == "comment":
-            yield CommentNode(text=elem.text, node=elem)
-        elif action == "pi":
-            yield ProcessingInstructionNode(name=elem.target, text=elem.text, node=elem)
-        else:
-            raise ValueError(f"Unknown event: {action}")
diff --git a/liiatools/cin_census_pipeline/stream_pipeline.py b/liiatools/cin_census_pipeline/stream_pipeline.py
index 28d3ff12..472e1108 100644
--- a/liiatools/cin_census_pipeline/stream_pipeline.py
+++ b/liiatools/cin_census_pipeline/stream_pipeline.py
@@ -1,29 +1,54 @@
-import tablib
+from pathlib import Path
 
 from xmlschema import XMLSchema
+import pandas as pd
+
+from sfdata_stream_parser.filters import generic
+
+from liiatools.common.data import FileLocator, ProcessResult, DataContainer
+from liiatools.common import stream_filters as stream_functions
+from liiatools.common.stream_parse import dom_parse
 
-from liiatools.common.data import FileLocator, ProcessResult
 from liiatools.cin_census_pipeline import stream_record
-from liiatools.cin_census_pipeline.stream_parse import dom_parse
 
 from . import stream_filters as filters
 
-# TODO: Should return a ProcessResult with a dataframe, not tablib
-
-def task_cleanfile(src_file: FileLocator, schema: XMLSchema) -> tablib.Dataset:
+def task_cleanfile(
+    src_file: FileLocator, schema: XMLSchema, schema_path: Path
+) -> ProcessResult:
+    """
+    Clean input cin census xml files according to schema and output clean data and errors
+
+    :param src_file: The pointer to a file in a virtual filesystem
+    :param schema: The data schema
+    :param schema_path: Path to the data schema
+    :return: A class containing a DataContainer and ErrorContainer
+    """
     with src_file.open("rb") as f:
-        stream = dom_parse(f)
-
-        stream = filters.strip_text(stream)
-        stream = filters.add_context(stream)
-        stream = filters.add_schema(stream, schema=schema)
+        # Open & Parse file
+        stream = dom_parse(f, filename=src_file.name)
+
+        # Configure stream
+        stream = stream_functions.strip_text(stream)
+        stream = stream_functions.add_context(stream)
+        stream = stream_functions.add_schema(stream, schema=schema)
+        stream = filters.add_column_spec(stream, schema_path=schema_path)
+
+        # Clean stream
+        stream = stream_functions.log_blanks(stream)
+        stream = stream_functions.conform_cell_types(stream)
+        stream = stream_functions.validate_elements(stream)
+
+        # Create dataset
+        error_holder, stream = stream_functions.collect_errors(stream)
+        stream = stream_record.message_collector(stream)
+        dataset_holder, stream = stream_record.export_table(stream)
 
-        stream = filters.convert_true_false(stream)
+        # Consume stream so we know it's been processed
+        generic.consume(stream)
 
-        stream = filters.validate_elements(stream)
-        stream = filters.remove_invalid(stream)
+        dataset = dataset_holder.value
+        errors = error_holder.value
 
-        stream = stream_record.message_collector(stream)
-        data = stream_record.export_table(stream)
+    dataset = DataContainer({k: pd.DataFrame(v) for k, v in dataset.items()})
 
-    return data
+    return ProcessResult(data=dataset, errors=errors)
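
Typical usage of the reworked task_cleanfile, mirroring the unit test near the end of this diff; the folder and filename are placeholders:

    from fs import open_fs

    from liiatools.common.data import FileLocator
    from liiatools.cin_census_pipeline.spec import load_schema, load_schema_path
    from liiatools.cin_census_pipeline.stream_pipeline import task_cleanfile

    samples_fs = open_fs(".")  # any folder holding a CIN census XML return
    locator = FileLocator(samples_fs, "cin-2022.xml")

    result = task_cleanfile(
        locator, schema=load_schema(2022), schema_path=load_schema_path(2022)
    )
    print(result.data["CIN"].shape)  # one pandas DataFrame per table
    print(len(result.errors))        # collected ErrorContainer entries
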
diff --git a/liiatools/cin_census_pipeline/stream_record.py b/liiatools/cin_census_pipeline/stream_record.py
index 92c09f12..02bc1f9e 100644
--- a/liiatools/cin_census_pipeline/stream_record.py
+++ b/liiatools/cin_census_pipeline/stream_record.py
@@ -1,65 +1,30 @@
+from more_itertools import peekable
 from typing import Iterator
 
-import tablib
-from more_itertools import peekable
-from sfdata_stream_parser import events
 from sfdata_stream_parser.collectors import xml_collector
+from sfdata_stream_parser import events
+from sfdata_stream_parser.filters.generic import generator_with_value
 
+from liiatools.common.stream_record import text_collector, HeaderEvent, _reduce_dict
 
-class CINEvent(events.ParseEvent):
-    pass
 
-
-class HeaderEvent(events.ParseEvent):
+class CINEvent(events.ParseEvent):
+    @staticmethod
+    def name():
+        return "CIN"
+
     pass
 
 
-def _reduce_dict(dict_instance):
+@xml_collector
+def cin_collector(stream):
     """
-    Reduces the values of a dictionary by unwrapping single-item lists.
+    Create a dictionary of text values for each CIN element;
+    Assessments, CINPlanDates, Section47 and ChildProtectionPlans
 
-    Parameters:
-    - dict_instance (dict): A dictionary where each key maps to a list.
-
-    Returns:
-    - dict: A new dictionary where single-item lists are unwrapped to their sole element.
-
-    Behavior:
-    - Iterates through each (key, value) pair in the input dictionary.
-    - If the value is a list with a single element, the function replaces the list with that element.
-    - Otherwise, the value is left as is.
-
-    Examples:
-    >>> _reduce_dict({'a': [1], 'b': [2, 3], 'c': [4, 5, 6]})
-    {'a': 1, 'b': [2, 3], 'c': [4, 5, 6]}
-
-    >>> _reduce_dict({'x': ['single'], 'y': ['multi', 'elements']})
-    {'x': 'single', 'y': ['multi', 'elements']}
+    :param stream: An iterator of events from an XML parser
+    :return: Dictionary containing element name and text values
     """
-    new_dict = {}
-    for key, value in dict_instance.items():
-        if len(value) == 1:
-            new_dict[key] = value[0]
-        else:
-            new_dict[key] = value
-    return new_dict
-
-
-@xml_collector
-def text_collector(stream):
-    data_dict = {}
-    current_element = None
-    for event in stream:
-        if isinstance(event, events.StartElement):
-            current_element = event.tag
-        if isinstance(event, events.TextNode) and event.text:
-            data_dict.setdefault(current_element, []).append(event.text)
-
-    return _reduce_dict(data_dict)
-
-
-@xml_collector
-def cin_collector(stream):
     data_dict = {}
     stream = peekable(stream)
     last_tag = None
@@ -74,8 +39,8 @@ def cin_collector(stream):
         ):
             data_dict.setdefault(event.tag, []).append(text_collector(stream))
         else:
-            if isinstance(event, events.TextNode) and event.text:
-                data_dict.setdefault(last_tag, []).append(event.text)
+            if isinstance(event, events.TextNode) and event.cell:
+                data_dict.setdefault(last_tag, []).append(event.cell)
             next(stream)
 
     return _reduce_dict(data_dict)
@@ -83,6 +48,12 @@ def cin_collector(stream):
 
 @xml_collector
 def child_collector(stream):
+    """
+    Create a dictionary of text values for each Child element; ChildIdentifiers, ChildCharacteristics and CINdetails
+
+    :param stream: An iterator of events from an XML parser
+    :return: Dictionary containing element name and text values
+    """
     data_dict = {}
     stream = peekable(stream)
     assert stream.peek().tag == "Child"
@@ -98,8 +69,13 @@ def child_collector(stream):
     return _reduce_dict(data_dict)
 
 
-@xml_collector
 def message_collector(stream):
+    """
+    Collect messages from XML elements and yield events
+
+    :param stream: An iterator of events from an XML parser
+    :yield: Events of type HeaderEvent or CINEvent
+    """
     stream = peekable(stream)
     assert stream.peek().tag == "Message", "Expected Message, got {}".format(
         stream.peek().tag
     )
@@ -150,7 +126,6 @@ def message_collector(stream):
             "ExpectedPersonBirthDate",
             "GenderCurrent",
             "PersonDeathDate",
-            "PersonSchoolYear",
            "Ethnicity",
             "Disabilities",
         ]
@@ -223,7 +198,7 @@ def cin_event(record, property, event_name=None):
     value = record.get(property)
     if value:
         new_record = {**record, "Date": value, "Type": event_name}
-        return ({k: new_record.get(k) for k in __EXPORT_HEADERS},)
+        return {k: new_record.get(k) for k in __EXPORT_HEADERS},
 
     return ()
 
@@ -291,10 +266,21 @@ def event_to_records(event: CINEvent) -> Iterator[dict]:
     )
 
 
+@generator_with_value
 def export_table(stream):
-    data = tablib.Dataset(headers=__EXPORT_HEADERS)
+    """
+    Collects all the records into a dictionary of lists of rows
+
+    This filter requires that the stream has been processed by `message_collector` first
+
+    :param stream: An iterator of events from message_collector
+    :yield: All events
+    :return: A dictionary of lists of rows, keyed by record name
+    """
+    dataset = {}
     for event in stream:
-        if isinstance(event, CINEvent):
-            for record in event_to_records(event):
-                data.append([record.get(k, "") for k in __EXPORT_HEADERS])
-    return data
+        event_type = type(event)
+        for record in event_to_records(event):
+            dataset.setdefault(event_type.name(), []).append(record)
+        yield event
+    return dataset
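
The generator_with_value pattern used by export_table can be non-obvious; a sketch of how the caller unpacks it (this is exactly what task_cleanfile does, and `stream` is assumed to have been through the cleaning filters already):

    from sfdata_stream_parser.filters import generic

    from liiatools.cin_census_pipeline import stream_record

    stream = stream_record.message_collector(stream)  # yields HeaderEvent/CINEvent
    dataset_holder, stream = stream_record.export_table(stream)
    generic.consume(stream)  # the holder only has a value once the stream is drained

    dataset = dataset_holder.value  # e.g. {"CIN": [{...row...}, ...]}
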
diff --git a/liiatools/common/_transform_functions.py b/liiatools/common/_transform_functions.py
index f134caa9..943524c4 100644
--- a/liiatools/common/_transform_functions.py
+++ b/liiatools/common/_transform_functions.py
@@ -41,12 +41,24 @@ def add_quarter(row: pd.Series, column_config: ColumnConfig, metadata: Metadata)
     return metadata["quarter"]
 
 
+def add_school_year(row: pd.Series, column_config: ColumnConfig, metadata: Metadata) -> str:
+    date_value = row["PersonBirthDate"]
+    if date_value.month >= 9:
+        school_year = date_value.year
+    elif date_value.month <= 8:
+        school_year = date_value.year - 1
+    else:
+        school_year = None
+    return school_year
+
+
 enrich_functions = {
     "add_la_suffix": add_la_suffix,
     "la_code": add_la_code,
     "la_name": add_la_name,
     "year": add_year,
     "quarter": add_quarter,
+    "school_year": add_school_year,
 }
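
The September cut-off is the piece worth double-checking here; a worked example (the NaT fall-through to None relies on NaT's field accessors returning nan, which fails both comparisons):

    import pandas as pd

    from liiatools.common._transform_functions import add_school_year

    row = pd.Series({"PersonBirthDate": pd.Timestamp(2015, 9, 1)})
    assert add_school_year(row, column_config=None, metadata={}) == 2015  # Sept: new school year

    row = pd.Series({"PersonBirthDate": pd.Timestamp(2015, 8, 31)})
    assert add_school_year(row, column_config=None, metadata={}) == 2014  # Aug: previous school year

    row = pd.Series({"PersonBirthDate": pd.NaT})
    assert add_school_year(row, column_config=None, metadata={}) is None
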
diff --git a/liiatools/common/stream_filters.py b/liiatools/common/stream_filters.py
index f1180845..8b400792 100644
--- a/liiatools/common/stream_filters.py
+++ b/liiatools/common/stream_filters.py
@@ -1,11 +1,13 @@
 import logging
-from io import BytesIO, StringIO
-from typing import Iterable, Union, Any, Dict
-
+import xmlschema
 import tablib
-from sfdata_stream_parser import events, collectors
+import xml.etree.ElementTree as ET
+from io import BytesIO, StringIO
+from typing import Iterable, Union, Any, Dict, List
+from pathlib import Path
 from tablib import import_book, import_set
+from sfdata_stream_parser import events, collectors
 from sfdata_stream_parser.checks import type_check
 from sfdata_stream_parser.filters.generic import (
     generator_with_value,
@@ -23,7 +25,7 @@
     to_category,
 )
 
-from .spec.__data_schema import Column, DataSchema
+from .spec.__data_schema import Column, DataSchema, Numeric, Category
 
 logger = logging.getLogger(__name__)
 
@@ -397,3 +399,209 @@ def collect_errors(stream):
 
     # With the stream fully consumed, we can return the collected errors
     return collected_errors
+
+
+@streamfilter(check=type_check(events.TextNode), fail_function=pass_event)
+def strip_text(event):
+    """
+    Strips surrounding whitespaces from :class:`sfdata_stream_parser.events.TextNode`. If the event does
+    not have a cell property then this filter fails silently.
+
+    :param event: A filtered list of event objects
+    :return: Event with whitespace stripped
+    """
+    if not hasattr(event, "cell"):
+        return event
+
+    if event.cell is None:
+        return event
+
+    cell = event.cell.strip()
+    if len(cell) > 0:
+        return event.from_event(event, cell=cell)
+    else:
+        return None
+
+
+@streamfilter(default_args=lambda: {"context": []})
+def add_context(event, context: List[str]):
+    """
+    Adds 'context' to XML structures. For each :class:`sfdata_stream_parser.events.StartElement` the tag name is
+    added to a 'context' tuple, and for each :class:`sfdata_stream_parser.events.EndElement` the context is popped.
+
+    For all other events, the context tuple is set as-is.
+
+    :param event: A filtered list of event objects
+    :param context: A list to be populated with context information
+    :return: Event with context
+    """
+    if isinstance(event, events.StartElement):
+        context.append(event.tag)
+        local_context = tuple(context)
+    elif isinstance(event, events.EndElement):
+        local_context = tuple(context)
+        context.pop()
+    else:
+        local_context = tuple(context)
+
+    return event.from_event(event, context=local_context)
+
+
+@streamfilter()
+def add_schema(event, schema: xmlschema.XMLSchema):
+    """
+    Requires each event to have event.context as set by :func:`add_context`
+
+    Based on the context (a tuple of element tags) it will set path which is the
+    derived path (based on the context tags) joined by '/' and schema holding the
+    corresponding schema element, if found.
+
+    :param event: A filtered list of event objects
+    :param schema: The xml schema to be attached to a given event
+    :return: Event with path, schema and header attributes
+    """
+    assert (
+        event.context
+    ), "This filter requires event.context to be set - see add_context"
+    path = "/".join(event.context)
+    tag = event.context[-1]
+    el = schema.get_element(tag, path)
+    header = getattr(el, "name", None)
+    return event.from_event(event, path=path, schema=el, header=header)
+
+
+def _get_validation_error(event, schema, node):
+    """
+    Validate an event
+
+    :param event: A filtered list of event objects
+    :param schema: The xml schema attached to a given event
+    :param node: The node attached to a given event
+    :return: Error information
+    """
+    try:
+        validation_error_iterator = schema.iter_errors(node)
+        for validation_error in validation_error_iterator:
+            if " expected" in validation_error.reason:
+                raise ValueError(
+                    f"Missing required field: '{validation_error.particle.name}' which occurs in the node starting on "
+                    f"line: {validation_error.sourceline}"
+                )
+
+    except AttributeError:  # Raised for nodes that don't exist in the schema
+        raise ValueError(f"Unexpected node '{event.tag}'")
+
+
+@streamfilter(check=type_check(events.StartElement), fail_function=pass_event)
+def validate_elements(event):
+    """
+    Validates each element, and if not valid raises ValidationError:
+
+    :param event: A filtered list of event objects
+    :return: Event if valid or event and error message if invalid
+    """
+    # Only validate the root element or elements with no schema
+    if isinstance(event, events.StartElement) and (
+        event.node.getparent() is None or event.schema is None
+    ):
+        try:
+            _get_validation_error(event, event.schema, event.node)
+            return event
+        except ValueError as e:
+            return EventErrors.add_to_event(
+                event, type="ValidationError", message="Invalid node", exception=str(e)
+            )
+    else:
+        return event
+
+
+def _create_category_spec(field: str, file: Path) -> List[Category] | None:
+    """
+    Create a list of Category classes containing the different categorical values of a given field to conform categories
+    e.g. [Category(code='0', name='Not an Agency Worker'), Category(code='1', name='Agency Worker')]
+
+    :param field: Name of the categorical field you want to find the values for
+    :param file: Path to the .xsd schema containing possible categories
+    :return: List of Category classes of categorical values and potential alternatives
+    """
+    category_spec = []
+
+    xsd_xml = ET.parse(file)
+    search_elem = f".//{{http://www.w3.org/2001/XMLSchema}}simpleType[@name='{field}']"
+    element = xsd_xml.find(search_elem)
+
+    if element is not None:
+        search_value = f".//{{http://www.w3.org/2001/XMLSchema}}enumeration"  # Find the 'code' parameter
+        value = element.findall(search_value)
+        if value:
+            for v in value:
+                category_spec.append(Category(code=v.get("value")))
+
+        search_doc = f".//{{http://www.w3.org/2001/XMLSchema}}documentation"  # Find the 'name' parameter
+        documentation = element.findall(search_doc)
+        for i, d in enumerate(documentation):
+            category_spec[i].name = d.text
+        return category_spec
+    else:
+        return
+
+
+def _create_numeric_spec(field: str, file: Path) -> Numeric:
+    """
+    Create a Numeric class containing the different numeric parameters of a given field to conform numbers
+    e.g. Numeric(type='float', min_value=0, max_value=1, decimal_places=6)
+
+    :param field: Name of the numeric field you want to find the parameters for
+    :param file: Path to the .xsd schema containing possible numeric parameters
+    :return: Numeric class of numeric parameters
+    """
+    numeric_spec = None
+
+    xsd_xml = ET.parse(file)
+    search_elem = f".//{{http://www.w3.org/2001/XMLSchema}}simpleType[@name='{field}']"
+    element = xsd_xml.find(search_elem)
+
+    search_restriction = f".//{{http://www.w3.org/2001/XMLSchema}}restriction"  # Find the 'type' parameter
+    restriction = element.findall(search_restriction)
+    for r in restriction:
+        if r.get("base")[3:] == "decimal":
+            numeric_spec = Numeric(type="float")
+        elif r.get("base")[3:] == "integer":
+            numeric_spec = Numeric(type="integer")
+
+    search_fraction_digits = f".//{{http://www.w3.org/2001/XMLSchema}}fractionDigits"  # Find the 'decimal' parameter
+    fraction_digits = element.findall(search_fraction_digits)
+    for f in fraction_digits:
+        numeric_spec.decimal_places = int(f.get("value"))
+
+    search_min_inclusive = f".//{{http://www.w3.org/2001/XMLSchema}}minInclusive"  # Find the 'min_value' parameter
+    min_inclusive = element.findall(search_min_inclusive)
+    for m in min_inclusive:
+        numeric_spec.min_value = int(m.get("value"))
+
+    search_max_inclusive = f".//{{http://www.w3.org/2001/XMLSchema}}maxInclusive"  # Find the 'max_value' parameter
+    max_inclusive = element.findall(search_max_inclusive)
+    for m in max_inclusive:
+        numeric_spec.max_value = int(m.get("value"))
+
+    return numeric_spec
+
+
+def _create_regex_spec(field: str, file: Path) -> str | None:
+    """
+    Parse an XML file and extract the regex pattern for a given field name
+
+    :param field: The name of the field to look for in the XML file
+    :param file: The path to the XML file
+    :return: The regex pattern, or None if no pattern is found
+    """
+    regex_spec = None
+
+    xsd_xml = ET.parse(file)
+    search_elem = f".//{{http://www.w3.org/2001/XMLSchema}}simpleType[@name='{field}']"
+    element = xsd_xml.find(search_elem)
+
+    search_pattern = f".//{{http://www.w3.org/2001/XMLSchema}}pattern"  # Find the 'cell_regex' parameter
+    pattern = element.findall(search_pattern)
+    for p in pattern:
+        regex_spec = p.get("value")
+
+    return regex_spec
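
To illustrate what the _create_*_spec helpers pull out of an .xsd, a hedged sketch; the simpleType content is invented but follows the CIN schema conventions:

    import tempfile
    from pathlib import Path

    from liiatools.common.stream_filters import _create_category_spec

    XSD = '''<?xml version="1.0"?>
    <xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
      <xs:simpleType name="yesnotype">
        <xs:restriction base="xs:string">
          <xs:enumeration value="false">
            <xs:annotation><xs:documentation>False</xs:documentation></xs:annotation>
          </xs:enumeration>
          <xs:enumeration value="true">
            <xs:annotation><xs:documentation>True</xs:documentation></xs:annotation>
          </xs:enumeration>
        </xs:restriction>
      </xs:simpleType>
    </xs:schema>'''

    with tempfile.NamedTemporaryFile(suffix=".xsd", delete=False) as f:
        f.write(XSD.encode("utf-8"))

    spec = _create_category_spec("yesnotype", Path(f.name))
    # -> [Category(code='false', name='False'), Category(code='true', name='True')]
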
diff --git a/liiatools/csww_pipeline/stream_parse.py b/liiatools/common/stream_parse.py
similarity index 100%
rename from liiatools/csww_pipeline/stream_parse.py
rename to liiatools/common/stream_parse.py
diff --git a/liiatools/common/stream_record.py b/liiatools/common/stream_record.py
new file mode 100644
index 00000000..1e5b4f24
--- /dev/null
+++ b/liiatools/common/stream_record.py
@@ -0,0 +1,65 @@
+from sfdata_stream_parser import events
+from sfdata_stream_parser.collectors import xml_collector
+from sfdata_stream_parser.filters.generic import generator_with_value
+
+
+class HeaderEvent(events.ParseEvent):
+    @staticmethod
+    def name():
+        return "Header"
+
+    pass
+
+
+def _reduce_dict(dict_instance):
+    """
+    Reduce lists in dictionary values to a single value if there is only one value in the list,
+    otherwise return the list
+
+    :param dict_instance: Dictionary containing lists in the values
+    :return: Dictionary with single values in the dictionary values if list length is one
+    """
+    new_dict = {}
+    for key, value in dict_instance.items():
+        if len(value) == 1:
+            new_dict[key] = value[0]
+        else:
+            new_dict[key] = value
+    return new_dict
+
+
+@xml_collector
+def text_collector(stream):
+    """
+    Create a dictionary of text values for each element
+
+    :param stream: An iterator of events from an XML parser
+    :return: Dictionary containing element name and text values
+    """
+    data_dict = {}
+    current_element = None
+    for event in stream:
+        if isinstance(event, events.StartElement):
+            current_element = event.tag
+        if isinstance(event, events.TextNode) and event.cell:
+            data_dict.setdefault(current_element, []).append(event.cell)
+    return _reduce_dict(data_dict)
+
+
+@generator_with_value
+def export_table(stream):
+    """
+    Collects all the records into a dictionary of lists of rows
+
+    This filter requires that the stream has been processed by `message_collector` first
+
+    :param stream: An iterator of events from message_collector
+    :yield: All events
+    :return: A dictionary of lists of rows, keyed by record name
+    """
+    dataset = {}
+    for event in stream:
+        event_type = type(event)
+        dataset.setdefault(event_type.name(), []).append(event.as_dict()["record"])
+        yield event
+    return dataset
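
The doctest-style examples that the old cin module carried for _reduce_dict still describe the shared helper accurately:

    from liiatools.common.stream_record import _reduce_dict

    assert _reduce_dict({"a": [1], "b": [2, 3]}) == {"a": 1, "b": [2, 3]}
    assert _reduce_dict({"x": ["single"], "y": ["multi", "elements"]}) == {
        "x": "single",
        "y": ["multi", "elements"],
    }
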
- - :param event: A filtered list of event objects - :return: Event with whitespace striped - """ - if not hasattr(event, "cell"): - return event - - if event.cell is None: - return event - - cell = event.cell.strip() - if len(cell) > 0: - return event.from_event(event, cell=cell) - else: - return None - - -@streamfilter(default_args=lambda: {"context": []}) -def add_context(event, context: List[str]): - """ - Adds 'context' to XML structures. For each :class:`sfdata_stream_parser.events.StartElement` the tag name is - added to a 'context' tuple, and for each :class:`sfdata_stream_parser.events.EndElement` the context is popped. - - For all other events, the context tuple is set as-is. - - :param event: A filtered list of event objects - :param context: A list to be populated with context information - :return: Event with context - """ - if isinstance(event, events.StartElement): - context.append(event.tag) - local_context = tuple(context) - elif isinstance(event, events.EndElement): - local_context = tuple(context) - context.pop() - else: - local_context = tuple(context) - - return event.from_event(event, context=local_context) - - -@streamfilter() -def add_schema(event, schema: xmlschema.XMLSchema): - """ - Requires each event to have event.context as set by :func:`add_context` - - Based on the context (a tuple of element tags) it will set path which is the - derived path (based on the context tags) joined by '/' and schema holding the - corresponding schema element, if found. - - :param event: A filtered list of event objects - :param schema: The xml schema to be attached to a given event - :return: Event with path, schema and header attributes - """ - assert ( - event.context - ), "This filter required event.context to be set - see add_context" - path = "/".join(event.context) - tag = event.context[-1] - el = schema.get_element(tag, path) - header = getattr(el, "name", None) - return event.from_event(event, path=path, schema=el, header=header) - - -def _get_validation_error(event, schema, node): - """ - Validate an event - - :param event: A filtered list of event objects - :param schema: The xml schema attached to a given event - :param node: The node attached to a given event - :return: Error information - """ - try: - validation_error_iterator = schema.iter_errors(node) - for validation_error in validation_error_iterator: - if " expected" in validation_error.reason: - raise ValueError( - f"Missing required field: '{validation_error.particle.name}' which occurs in the node starting on " - f"line: {validation_error.sourceline}" - ) - - except AttributeError: # Raised for nodes that don't exist in the schema - raise ValueError(f"Unexpected node '{event.tag}'") - - -@streamfilter(check=type_check(events.StartElement), fail_function=pass_event) -def validate_elements(event): - """ - Validates each element, and if not valid raises ValidationError: - - :param event: A filtered list of event objects - :return: Event if valid or event and error message if invalid - """ - # Only validate root element and elements with no schema - if isinstance(event, events.StartElement) and (event.node.getparent() is None or event.schema is None): - try: - _get_validation_error(event, event.schema, event.node) - return event - except ValueError as e: - return EventErrors.add_to_event( - event, type="ValidationError", message=f"Invalid node", exception=str(e) - ) - else: - return event - - -def _create_category_spec(field: str, file: Path) -> List[Category] | None: - """ - Create a list of Category classes 
containing the different categorical values of a given field to conform categories - e.g. [Category(code='0', name='Not an Agency Worker'), Category(code='1', name='Agency Worker')] - - :param field: Name of the categorical field you want to find the values for - :param file: Path to the .xsd schema containing possible categories - :return: List of Category classes of categorical values and potential alternatives - """ - category_spec = [] - - xsd_xml = ET.parse(file) - search_elem = f".//{{http://www.w3.org/2001/XMLSchema}}simpleType[@name='{field}']" - element = xsd_xml.find(search_elem) - - if element is not None: - search_value = f".//{{http://www.w3.org/2001/XMLSchema}}enumeration" # Find the 'code' parameter - value = element.findall(search_value) - if value: - for v in value: - category_spec.append(Category(code=v.get("value"))) - - search_doc = f".//{{http://www.w3.org/2001/XMLSchema}}documentation" # Find the 'name' parameter - documentation = element.findall(search_doc) - for i, d in enumerate(documentation): - category_spec[i].name = d.text - return category_spec - else: - return - - -def _create_numeric_spec(field: str, file: Path) -> Numeric: - """ - Create a Numeric class containing the different numeric parameters of a given field to conform numbers - e.g. Numeric(type='float', min_value=0, max_value=1, decimal_places=6) - - :param field: Name of the numeric field you want to find the parameters for - :param file: Path to the .xsd schema containing possible numeric parameters - :return: Numeric class of numeric parameters - """ - numeric_spec = None - - xsd_xml = ET.parse(file) - search_elem = f".//{{http://www.w3.org/2001/XMLSchema}}simpleType[@name='{field}']" - element = xsd_xml.find(search_elem) - - search_restriction = f".//{{http://www.w3.org/2001/XMLSchema}}restriction" # Find the 'type' parameter - restriction = element.findall(search_restriction) - for r in restriction: - if r.get("base")[3:] == "decimal": - numeric_spec = Numeric(type="float") - - search_fraction_digits = f".//{{http://www.w3.org/2001/XMLSchema}}fractionDigits" # Find the 'decimal' parameter - fraction_digits = element.findall(search_fraction_digits) - for f in fraction_digits: - numeric_spec.decimal_places = int(f.get("value")) - - search_min_inclusive = f".//{{http://www.w3.org/2001/XMLSchema}}minInclusive" # Find the 'min_value' parameter - min_inclusive = element.findall(search_min_inclusive) - for m in min_inclusive: - numeric_spec.min_value = int(m.get("value")) - - search_max_inclusive = f".//{{http://www.w3.org/2001/XMLSchema}}maxInclusive" # Find the 'max_value' parameter - max_inclusive = element.findall(search_max_inclusive) - for m in max_inclusive: - numeric_spec.max_value = int(m.get("value")) - - return numeric_spec - - -def _create_regex_spec(field: str, file: Path) -> str | None: - """ - Parse an XML file and extract the regex pattern for a given field name - - :param field: The name of the field to look for in the XML file - :param file: The path to the XML file - :return: The regex pattern, or None if no pattern is found - """ - regex_spec = None - - xsd_xml = ET.parse(file) - search_elem = f".//{{http://www.w3.org/2001/XMLSchema}}simpleType[@name='{field}']" - element = xsd_xml.find(search_elem) - - search_pattern = f".//{{http://www.w3.org/2001/XMLSchema}}pattern" # Find the 'cell_regex' parameter - pattern = element.findall(search_pattern) - for p in pattern: - regex_spec = p.get("value") - - return regex_spec - - @streamfilter( check=type_check(events.TextNode), 
fail_function=pass_event, diff --git a/liiatools/csww_pipeline/stream_pipeline.py b/liiatools/csww_pipeline/stream_pipeline.py index bc78e3e2..24059f55 100644 --- a/liiatools/csww_pipeline/stream_pipeline.py +++ b/liiatools/csww_pipeline/stream_pipeline.py @@ -1,13 +1,15 @@ -import pandas as pd -from xmlschema import XMLSchema from pathlib import Path +from xmlschema import XMLSchema +import pandas as pd from sfdata_stream_parser.filters import generic from liiatools.common.data import FileLocator, ProcessResult, DataContainer from liiatools.common import stream_filters as stream_functions +from liiatools.common.stream_parse import dom_parse +from liiatools.common.stream_record import export_table + from liiatools.csww_pipeline import stream_record -from liiatools.csww_pipeline.stream_parse import dom_parse from . import stream_filters as filters @@ -27,20 +29,20 @@ def task_cleanfile( stream = dom_parse(f, filename=src_file.name) # Configure stream - stream = filters.strip_text(stream) - stream = filters.add_context(stream) - stream = filters.add_schema(stream, schema=schema) + stream = stream_functions.strip_text(stream) + stream = stream_functions.add_context(stream) + stream = stream_functions.add_schema(stream, schema=schema) stream = filters.add_column_spec(stream, schema_path=schema_path) # Clean stream stream = stream_functions.log_blanks(stream) stream = stream_functions.conform_cell_types(stream) - stream = filters.validate_elements(stream) + stream = stream_functions.validate_elements(stream) # Create dataset error_holder, stream = stream_functions.collect_errors(stream) stream = stream_record.message_collector(stream) - dataset_holder, stream = stream_record.export_table(stream) + dataset_holder, stream = export_table(stream) # Consume stream so we know it's been processed generic.consume(stream) diff --git a/liiatools/csww_pipeline/stream_record.py b/liiatools/csww_pipeline/stream_record.py index a91be1ff..6b5bb2a5 100644 --- a/liiatools/csww_pipeline/stream_record.py +++ b/liiatools/csww_pipeline/stream_record.py @@ -1,8 +1,8 @@ from more_itertools import peekable from sfdata_stream_parser import events -from sfdata_stream_parser.collectors import xml_collector -from sfdata_stream_parser.filters.generic import generator_with_value + +from liiatools.common.stream_record import text_collector, HeaderEvent class CSWWEvent(events.ParseEvent): @@ -21,49 +21,6 @@ def name(): pass -class HeaderEvent(events.ParseEvent): - @staticmethod - def name(): - return "Header" - - pass - - -def _reduce_dict(dict_instance): - """ - Reduce lists in dictionary values to a single value if there is only one value in the list, - otherwise return the list - - :param dict_instance: Dictionary containing lists in the values - :return: Dictionary with single values in the dictionary values if list length is one - """ - new_dict = {} - for key, value in dict_instance.items(): - if len(value) == 1: - new_dict[key] = value[0] - else: - new_dict[key] = value - return new_dict - - -@xml_collector -def text_collector(stream): - """ - Create a dictionary of text values for each element - - :param stream: An iterator of events from an XML parser - :return: Dictionary containing element name and text values - """ - data_dict = {} - current_element = None - for event in stream: - if isinstance(event, events.StartElement): - current_element = event.tag - if isinstance(event, events.TextNode) and event.cell: - data_dict.setdefault(current_element, []).append(event.cell) - return _reduce_dict(data_dict) - - def 
message_collector(stream): """ Collect messages from XML elements and yield events @@ -89,22 +46,3 @@ def message_collector(stream): yield LALevelEvent(record=lalevel_record) else: next(stream) - - -@generator_with_value -def export_table(stream): - """ - Collects all the records into a dictionary of lists of rows - - This filter requires that the stream has been processed by `message_collector` first - - :param stream: An iterator of events from message_collector - :yield: All events - :return: A dictionary of lists of rows, keyed by record name - """ - dataset = {} - for event in stream: - event_type = type(event) - dataset.setdefault(event_type.name(), []).append(event.as_dict()["record"]) - yield event - return dataset diff --git a/tests/cin_census/test_schema.py b/tests/cin_census/test_config.py similarity index 100% rename from tests/cin_census/test_schema.py rename to tests/cin_census/test_config.py diff --git a/tests/cin_census/test_converter.py b/tests/cin_census/test_converter.py deleted file mode 100644 index 4bd2a567..00000000 --- a/tests/cin_census/test_converter.py +++ /dev/null @@ -1,40 +0,0 @@ -# from liiatools.datasets.cin_census.lds_cin_clean import converter -# from sfdata_stream_parser import events -# -# -# def test_convert_true_false(): -# class Schema: -# def __init__(self): -# self.type = Name() -# -# class Name: -# def __init__(self): -# self.name = "yesnotype" -# -# stream = converter.convert_true_false( -# [ -# events.TextNode(text="false", schema=Schema()), -# events.TextNode(text="true", schema=Schema()), -# events.TextNode(text="TRUE", schema=Schema()), -# ] -# ) -# stream = list(stream) -# assert stream[0].text == "0" -# assert stream[1].text == "1" -# assert stream[2].text == "1" -# -# class Name: -# def __init__(self): -# self.name = "other_type" -# -# stream = converter.convert_true_false( -# [ -# events.TextNode(text="false", schema=Schema()), -# events.TextNode(text="true", schema=Schema()), -# events.TextNode(text="true"), -# ] -# ) -# stream = list(stream) -# assert stream[0].text == "false" -# assert stream[1].text == "true" -# assert stream[2].text == "true" diff --git a/tests/cin_census/test_end_to_end.py b/tests/cin_census/test_end_to_end.py index 04bcc9e1..cbe35741 100644 --- a/tests/cin_census/test_end_to_end.py +++ b/tests/cin_census/test_end_to_end.py @@ -7,6 +7,7 @@ import liiatools from liiatools.__main__ import cli +from liiatools.cin_census_pipeline.spec.samples import CIN_2022 @pytest.fixture(scope="session", autouse=True) @@ -32,7 +33,35 @@ def log_dir(build_dir): @pytest.mark.skipif(os.environ.get("SKIP_E2E"), reason="Skipping end-to-end tests") -def test_end_to_end(liiatools_dir, build_dir, log_dir): +def test_end_to_end(liiatools_dir, build_dir): + incoming_dir = build_dir / "incoming" + incoming_dir.mkdir(parents=True, exist_ok=True) + pipeline_dir = build_dir / "pipeline" + pipeline_dir.mkdir(parents=True, exist_ok=True) + + shutil.copy(CIN_2022, incoming_dir / f"cin-2022.xml") + + runner = CliRunner() + result = runner.invoke( + cli, + [ + "cin-census", + "pipeline", + "-c", + "BAD", + "--input", + incoming_dir.as_posix(), + "--output", + pipeline_dir.as_posix(), + ], + catch_exceptions=False, + ) + + assert result.exit_code == 0 + + +@pytest.mark.skip("Old pipeline") +def test_end_to_end_old(liiatools_dir, build_dir, log_dir): runner = CliRunner() result = runner.invoke( cli, diff --git a/tests/cin_census/test_file_creator.py b/tests/cin_census/test_file_creator.py deleted file mode 100644 index 9cf7ab1e..00000000 --- 
diff --git a/tests/cin_census/test_file_creator.py b/tests/cin_census/test_file_creator.py
deleted file mode 100644
index 9cf7ab1e..00000000
--- a/tests/cin_census/test_file_creator.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# from liiatools.datasets.cin_census.lds_cin_clean import file_creator
-#
-# import pandas as pd
-# from datetime import datetime
-#
-#
-# def test_get_year():
-#     year = "2022"
-#
-#     data = {
-#         "CHILD ID": [123, 456],
-#         "DOB": [datetime(2019, 4, 15).date(), datetime(2015, 7, 19).date()],
-#     }
-#     data = pd.DataFrame(data=data)
-#
-#     stream = file_creator.get_year(data, year)
-#     assert stream["YEAR"].tolist() == ["2022", "2022"]
diff --git a/tests/cin_census/test_reports.py b/tests/cin_census/test_reports.py
index 9403d2c0..6c6d2634 100644
--- a/tests/cin_census/test_reports.py
+++ b/tests/cin_census/test_reports.py
@@ -1,6 +1,9 @@
 import pandas as pd
 
-from liiatools.cin_census_pipeline.reports import expanded_assessment_factors
+from liiatools.cin_census_pipeline.reports import (
+    expanded_assessment_factors,
+    referral_outcomes
+)
 
 
 def test_assessment_factors():
diff --git a/tests/cin_census/test_stream_pipeline.py b/tests/cin_census/test_stream_pipeline.py
index 02c2c149..8d786757 100644
--- a/tests/cin_census/test_stream_pipeline.py
+++ b/tests/cin_census/test_stream_pipeline.py
@@ -1,6 +1,8 @@
 from fs import open_fs
+import xml.etree.ElementTree as ET
+import os
 
-from liiatools.cin_census_pipeline.spec import load_schema
+from liiatools.cin_census_pipeline.spec import load_schema, load_schema_path
 from liiatools.cin_census_pipeline.spec.samples import CIN_2022
 from liiatools.cin_census_pipeline.spec.samples import DIR as SAMPLES_DIR
 from liiatools.cin_census_pipeline.stream_pipeline import task_cleanfile
@@ -11,21 +13,46 @@ def test_task_cleanfile():
     samples_fs = open_fs(SAMPLES_DIR.as_posix())
     locator = FileLocator(samples_fs, CIN_2022.name)
 
-    data = task_cleanfile(locator, schema=load_schema(2022))
+    result = task_cleanfile(locator, schema=load_schema(2022), schema_path=load_schema_path(2022))
 
-    assert len(data) == 10
-    assert len(data.headers) == 34
+    data = result.data
+    errors = result.errors
 
+    assert len(data) == 1
+    assert len(data["CIN"]) == 10
+    assert len(data["CIN"].columns) == 33
 
-if __name__ == "__main__":
-    from pathlib import Path
+    assert len(errors) == 0
+
+
+def test_task_cleanfile_error():
+    tree = ET.parse(CIN_2022)
+    root = tree.getroot()
+
+    parent = root.find(".//Source")
+    el = parent.find("DateTime")
+    el.text = el.text.replace("2022-05-23T11:14:05", "not_date")
+
+    tree.write(SAMPLES_DIR / 'cin_2022_error.xml')
 
     samples_fs = open_fs(SAMPLES_DIR.as_posix())
-    locator = FileLocator(samples_fs, CIN_2022.name)
+    locator = FileLocator(samples_fs, "cin_2022_error.xml")
+
+    result = task_cleanfile(
+        locator, schema=load_schema(2022), schema_path=load_schema_path(2022)
+    )
+
+    data = result.data
+    errors = result.errors
 
-    data = task_cleanfile(locator, schema=load_schema(2022))
+    assert len(data) == 1
+    assert len(data["CIN"]) == 10
+    assert len(data["CIN"].columns) == 33
 
-    build_dir = Path(__file__).parent.parent.parent / "build"
-    build_dir.mkdir(exist_ok=True)
+    assert errors[0]["type"] == "ConversionError"
+    assert errors[0]["message"] == "Could not convert to date"
+    assert errors[0]["exception"] == "Invalid date: not_date"
+    assert errors[0]["filename"] == "cin_2022_error.xml"
+    assert errors[0]["header"] == "DateTime"
 
-    (build_dir / "cin_census_2022_flatfile.csv").write_text(data.export("csv"))
+    os.remove(SAMPLES_DIR / "cin_2022_error.xml")
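
The reworked test above pins down the new return shape of `task_cleanfile`: a `ProcessResult` whose `data` is keyed by table name and whose `errors` is a list of dicts. A minimal sketch of inspecting it, assuming only what the test asserts:

```python
from fs import open_fs

from liiatools.common.data import FileLocator
from liiatools.cin_census_pipeline.spec import load_schema, load_schema_path
from liiatools.cin_census_pipeline.spec.samples import CIN_2022, DIR as SAMPLES_DIR
from liiatools.cin_census_pipeline.stream_pipeline import task_cleanfile

samples_fs = open_fs(SAMPLES_DIR.as_posix())
locator = FileLocator(samples_fs, CIN_2022.name)

result = task_cleanfile(
    locator, schema=load_schema(2022), schema_path=load_schema_path(2022)
)

# result.data behaves like a dict of tables; each error dict carries
# "type", "message", "exception", "filename" and "header" keys.
print(len(result.data["CIN"]))
for error in result.errors:
    print(error["type"], error["header"], error["exception"])
```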
diff --git a/tests/cin_census/test_validator.py b/tests/cin_census/test_validator.py
deleted file mode 100644
index 9a8a1091..00000000
--- a/tests/cin_census/test_validator.py
+++ /dev/null
@@ -1,134 +0,0 @@
-import sys
-import xml.etree.ElementTree as ET
-from io import BytesIO
-from typing import Iterable
-
-import yaml
-from sfdata_stream_parser.events import ParseEvent, StartElement
-from sfdata_stream_parser.parser.xml import parse
-from xmlschema.validators.exceptions import XMLSchemaValidatorError
-from xmlschema.validators.facets import XsdMinLengthFacet
-from xmlschema.validators.groups import XsdGroup
-
-from liiatools.cin_census_pipeline import stream_filters as filters
-from liiatools.cin_census_pipeline.spec import load_schema
-from liiatools.cin_census_pipeline.spec.samples import CIN_2022
-from liiatools.cin_census_pipeline.spec.samples import DIR as SAMPLES_DIR
-from liiatools.cin_census_pipeline.stream_pipeline import task_cleanfile
-from liiatools.cin_census_pipeline.stream_parse import dom_parse
-
-
-def _xml_to_stream(root) -> Iterable[ParseEvent]:
-    schema = load_schema(2022)
-
-    input = BytesIO(ET.tostring(root, encoding="utf-8"))
-    stream = dom_parse(input)
-    stream = filters.strip_text(stream)
-    stream = filters.add_context(stream)
-    stream = filters.add_schema(stream, schema=schema)
-    stream = filters.validate_elements(stream)
-    return list(stream)
-
-
-def test_validate_all_valid():
-    with CIN_2022.open("rb") as f:
-        root = ET.parse(f).getroot()
-
-    stream = _xml_to_stream(root)
-
-    # Count nodes in stream that are not valid
-    invalid_nodes = [e for e in stream if not getattr(e, "valid", True)]
-    assert len(invalid_nodes) == 0
-
-
-def test_validate_missing_child_id():
-    with CIN_2022.open("rb") as f:
-        root = ET.parse(f).getroot()
-
-    parent = root.find(".//ChildIdentifiers")
-    el = parent.find("LAchildID")
-    parent.remove(el)
-
-    stream = _xml_to_stream(root)
-
-    # Count nodes in stream that are not valid
-    invalid_nodes = [e for e in stream if not getattr(e, "valid", True)]
-    assert len(invalid_nodes) == 1
-
-    error: XMLSchemaValidatorError = invalid_nodes[0].validation_errors[0]
-    assert type(error.validator) == XsdGroup
-    assert error.particle.name == "LAchildID"
-    assert error.occurs == 0
-    assert error.sourceline == 19
-
-
-def test_validate_blank_child_id():
-    with CIN_2022.open("rb") as f:
-        root = ET.parse(f).getroot()
-
-    el = root.find(".//LAchildID")
-    el.text = ""
-    stream = _xml_to_stream(root)
-
-    # Count nodes in stream that are not valid
-    invalid_nodes = [e for e in stream if not getattr(e, "valid", True)]
-    assert len(invalid_nodes) == 1
-
-    error: XMLSchemaValidatorError = invalid_nodes[0].validation_errors[0]
-    assert type(error.validator) == XsdMinLengthFacet
-    assert error.reason == "value length cannot be lesser than 1"
-    assert error.sourceline == 20
-
-
-def test_validate_reordered_child_id():
-    with CIN_2022.open("rb") as f:
-        root = ET.parse(f).getroot()
-
-    el_parent = root.find(".//LAchildID/..")
-    el_child_id = el_parent.find("LAchildID")
-    el_parent.remove(el_child_id)
-    el_parent.append(el_child_id)
-
-    xml = ET.tostring(el_parent, encoding="utf-8")
-    stream = _xml_to_stream(root)
-
-    # Count nodes in stream that are not valid
-    invalid_nodes = [e for e in stream if not getattr(e, "valid", True)]
-    assert len(invalid_nodes) == 1
-
-    error: XMLSchemaValidatorError = invalid_nodes[0].validation_errors[0]
-    assert type(error.validator) == XsdGroup
-    assert error.particle.name == "LAchildID"
-    assert error.occurs == 0
-    assert error.sourceline == 19
-
-
-class FakeLocator:
-    def __init__(self, data):
-        self.data = data
-        self.meta = {}
-
-    def open(self, mode: str = "r"):
-        return BytesIO(self.data)
-
-
-def test_remove_invalid():
-    with CIN_2022.open("rb") as f:
-        root = ET.parse(f).getroot()
-
-    xml_string = ET.tostring(root, encoding="utf-8")
-    locator = FakeLocator(xml_string)
-
-    data = task_cleanfile(locator, schema=load_schema(2022))
-    child_ids = set(data["LAchildID"])
-    assert child_ids == {"DfEX0000001"}
-
-    el = root.find(".//LAchildID")
-    el.text = ""
-
-    xml_string = ET.tostring(root, encoding="utf-8")
-    locator = FakeLocator(xml_string)
-
-    data = task_cleanfile(locator, schema=load_schema(2022))
-    child_ids = set(data["LAchildID"])
-    assert child_ids == {None}
diff --git a/tests/common/test_filters.py b/tests/common/test_filters.py
index b559a06e..b3928add 100644
--- a/tests/common/test_filters.py
+++ b/tests/common/test_filters.py
@@ -409,3 +409,14 @@ def test_clean_regex():
     cleaned_event = list(stream_filters.conform_cell_types(event))[0]
     assert cleaned_event.cell == ""
     assert_errors(cleaned_event)
+
+    regex_spec = Column(string="regex", cell_regex=r"[A-Za-z]\d{11}(\d|[A-Za-z])")
+    event = events.Cell(cell="A123456789012", column_spec=regex_spec)
+    cleaned_event = list(stream_filters.conform_cell_types(event))[0]
+    assert cleaned_event.cell == "A123456789012"
+    assert_errors(cleaned_event)
+
+    event = events.Cell(cell="A12345678901B", column_spec=regex_spec)
+    cleaned_event = list(stream_filters.conform_cell_types(event))[0]
+    assert cleaned_event.cell == "A12345678901B"
+    assert_errors(cleaned_event)
diff --git a/tests/social_work_workforce/test_parse.py b/tests/common/test_parse.py
similarity index 98%
rename from tests/social_work_workforce/test_parse.py
rename to tests/common/test_parse.py
index 4d17ef26..915f6e0f 100644
--- a/tests/social_work_workforce/test_parse.py
+++ b/tests/common/test_parse.py
@@ -1,6 +1,6 @@
 from io import BytesIO
 
-from liiatools.csww_pipeline.stream_parse import dom_parse
+from liiatools.common.stream_parse import dom_parse
 from sfdata_stream_parser.events import (
     StartElement,
     EndElement,
diff --git a/tests/common/test_stream_filters.py b/tests/common/test_stream_filters.py
index 32f24f9a..678f4f31 100644
--- a/tests/common/test_stream_filters.py
+++ b/tests/common/test_stream_filters.py
@@ -1,10 +1,34 @@
+import xml.etree.ElementTree as ET
+from io import BytesIO
+from typing import Iterable
 from fs import open_fs
-from sfdata_stream_parser.events import StartContainer
+
+
+from sfdata_stream_parser.events import StartElement, EndElement, TextNode, ParseEvent, StartContainer
 
 from liiatools.common.data import FileLocator
 from liiatools.common.stream_filters import tablib_parse
+from liiatools.common.stream_parse import dom_parse
+from liiatools.common.spec.__data_schema import (
+    Numeric,
+    Category,
+)
+from liiatools.common.stream_filters import (
+    strip_text,
+    add_context,
+    add_schema,
+    validate_elements,
+    _create_category_spec,
+    _create_numeric_spec,
+    _create_regex_spec,
+)
 from liiatools.annex_a_pipeline.spec.samples import DIR as DIR_AA
 from liiatools.ssda903_pipeline.spec.samples import DIR as DIR_903
+from liiatools.csww_pipeline.spec.samples import CSWW_2022
+from liiatools.csww_pipeline.spec import (
+    load_schema,
+    load_schema_path,
+)
 
 
 def test_parse_tabular_csv():
@@ -40,3 +64,190 @@ def test_parse_with_alternative_name():
     stream = list(stream)
     assert stream
     assert stream[0] == StartContainer(filename="/year/2020/episodes.csv")
+
+
+def test_strip_text():
+    stream = [
+        TextNode(text=None),
+        TextNode(cell=None, text=None),
+        TextNode(cell="string", text=None),
+        TextNode(cell=" string_with_whitespace ", text=None),
+    ]
+
+    stripped_stream = list(strip_text(stream))
+    assert stream[0] == stripped_stream[0]
+    assert stream[1] == stripped_stream[1]
+    assert stripped_stream[2].cell == "string"
+    assert stripped_stream[3].cell == "string_with_whitespace"
+
+
+def test_add_context():
+    stream = [
+        StartElement(tag="Message"),
+        StartElement(tag="Header"),
+        TextNode(cell="string", text=None),
+        EndElement(tag="Header"),
+        EndElement(tag="Message"),
+    ]
+
+    context_stream = list(add_context(stream))
+    assert context_stream[0].context == ("Message",)
+    assert context_stream[1].context == ("Message", "Header")
+    assert context_stream[2].context == ("Message", "Header")
+    assert context_stream[3].context == ("Message", "Header")
+    assert context_stream[4].context == ("Message",)
+
+
+def test_add_schema():
+    schema = load_schema(year=2022)
+    stream = [
+        StartElement(tag="Message", context=("Message",)),
+        StartElement(tag="Header", context=("Message", "Header")),
+        TextNode(cell="string", text=None, context=("Message", "Header")),
+        EndElement(tag="Header", context=("Message", "Header")),
+        EndElement(tag="Message", context=("Message",)),
+    ]
+
+    schema_stream = list(add_schema(stream, schema=schema))
+
+    assert schema_stream[0].schema.name == "Message"
+    assert schema_stream[0].schema.occurs == (1, 1)
+    assert schema_stream[1].schema.name == "Header"
+    assert schema_stream[1].schema.occurs == (0, 1)
+    assert schema_stream[2].schema.name == "Header"
+    assert schema_stream[2].schema.occurs == (0, 1)
+    assert schema_stream[3].schema.name == "Header"
+    assert schema_stream[3].schema.occurs == (0, 1)
+    assert schema_stream[4].schema.name == "Message"
+    assert schema_stream[4].schema.occurs == (1, 1)
+
+
+def _xml_to_stream(root) -> Iterable[ParseEvent]:
+    schema = load_schema(2022)
+
+    input = BytesIO(ET.tostring(root, encoding="utf-8"))
+    stream = dom_parse(input, filename="test.xml")
+    stream = strip_text(stream)
+    stream = add_context(stream)
+    stream = add_schema(stream, schema=schema)
+    stream = validate_elements(stream)
+    return list(stream)
+
+
+def test_validate_all_valid():
+    with CSWW_2022.open("rb") as f:
+        root = ET.parse(f).getroot()
+
+    stream = _xml_to_stream(root)
+
+    for event in stream:
+        assert not hasattr(event, "errors")
+
+
+def test_validate_missing_required_field():
+    with CSWW_2022.open("rb") as f:
+        root = ET.parse(f).getroot()
+
+    parent = root.find(".//CSWWWorker")
+    el = parent.find("AgencyWorker")
+    parent.remove(el)
+
+    stream = _xml_to_stream(root)
+
+    errors = []
+    for event in stream:
+        if hasattr(event, "errors"):
+            errors.append(event.errors)
+
+    assert list(errors[0])[0] == {
+        "type": "ValidationError",
+        "message": "Invalid node",
+        "exception": "Missing required field: 'AgencyWorker' which occurs in the node starting on line: 20",
+    }
+
+
+def test_validate_reordered_required_field():
+    with CSWW_2022.open("rb") as f:
+        root = ET.parse(f).getroot()
+
+    el_parent = root.find(".//AgencyWorker/..")
+    el_child_id = el_parent.find("AgencyWorker")
+    el_parent.remove(el_child_id)
+    el_parent.append(el_child_id)
+
+    stream = _xml_to_stream(root)
+
+    errors = []
+    for event in stream:
+        if hasattr(event, "errors"):
+            errors.append(event.errors)
+
+    assert list(errors[0])[0] == {
+        "type": "ValidationError",
+        "message": "Invalid node",
+        "exception": "Missing required field: 'AgencyWorker' which occurs in the node starting on line: 20",
+    }
+
+
+def test_validate_unexpected_node():
+    with CSWW_2022.open("rb") as f:
+        root = ET.parse(f).getroot()
+
+    parent = root.find(".//CSWWWorker")
+    ET.SubElement(parent, "Unknown_Node")
+
+    stream = _xml_to_stream(root)
+
+    errors = []
+    for event in stream:
+        if hasattr(event, "errors"):
+            errors.append(event.errors)
+
+    assert list(errors[0])[0] == {
+        "type": "ValidationError",
+        "message": "Invalid node",
+        "exception": "Unexpected node 'Unknown_Node'",
+    }
+
+
+def test_create_category_spec():
+    schema_path = load_schema_path(2022)
+    field = "agencyworkertype"
+    category_spec = _create_category_spec(field, schema_path)
+
+    assert category_spec == [
+        Category(
+            code="0",
+            name="Not an Agency Worker",
+            cell_regex=None,
+            model_config={"extra": "forbid"},
+        ),
+        Category(
+            code="1",
+            name="Agency Worker",
+            cell_regex=None,
+            model_config={"extra": "forbid"},
+        ),
+    ]
+
+
+def test_create_numeric_spec():
+    schema_path = load_schema_path(2022)
+    field = "twodecimalplaces"
+    numeric_spec = _create_numeric_spec(field, schema_path)
+
+    assert numeric_spec == Numeric(
+        type="float",
+        min_value=None,
+        max_value=None,
+        decimal_places=2,
+        model_config={"extra": "forbid"},
+    )
+
+
+def test_create_regex_spec():
+    schema_path = load_schema_path(2022)
+    field = "swetype"
+    regex_spec = _create_regex_spec(field, schema_path)
+
+    assert regex_spec == r"[A-Za-z]{2}\d{10}"
\ No newline at end of file
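
Tying the regex pieces together: `_create_regex_spec` pulls a pattern out of the XSD, and `conform_cell_types` applies it through a `Column` spec. A sketch grounded in the assertions above (the pattern is the one the new test_clean_regex additions use):

```python
from sfdata_stream_parser import events

from liiatools.common import stream_filters
from liiatools.common.spec.__data_schema import Column

# A cell matching the configured pattern passes through unchanged.
regex_spec = Column(string="regex", cell_regex=r"[A-Za-z]\d{11}(\d|[A-Za-z])")
event = events.Cell(cell="A123456789012", column_spec=regex_spec)
cleaned = list(stream_filters.conform_cell_types(event))[0]
assert cleaned.cell == "A123456789012"
```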
diff --git a/tests/common/test_stream_record.py b/tests/common/test_stream_record.py
new file mode 100644
index 00000000..fe157136
--- /dev/null
+++ b/tests/common/test_stream_record.py
@@ -0,0 +1,91 @@
+import unittest
+
+from sfdata_stream_parser.events import StartElement, EndElement, TextNode
+
+from liiatools.common.stream_record import (
+    _reduce_dict,
+    text_collector,
+    export_table,
+)
+
+from liiatools.csww_pipeline.stream_record import message_collector
+
+
+def test_reduce_dict():
+    sample_dict = {
+        "ID": ["100"],
+        "SWENo": ["AB123456789"],
+        "Agency": ["0"],
+        "ReasonAbsence": ["MAT", "OTH"],
+    }
+    assert _reduce_dict(sample_dict) == {
+        "ID": "100",
+        "SWENo": "AB123456789",
+        "Agency": "0",
+        "ReasonAbsence": ["MAT", "OTH"],
+    }
+
+
+class TestRecord(unittest.TestCase):
+    def generate_text_element(self, tag: str, cell):
+        """
+        Create a complete TextNode sandwiched between a StartElement and EndElement
+
+        :param tag: XML tag
+        :param cell: text to be stored in the given XML tag, could be a string, integer, float etc.
+        :return: StartElement and EndElement with given tags and TextNode with given text
+        """
+        yield StartElement(tag=tag)
+        yield TextNode(cell=str(cell), text=None)
+        yield EndElement(tag=tag)
+
+    def generate_test_csww_file(self):
+        """
+        Generate a sample children's social work workforce census file
+
+        :return: stream of generators containing information required to create an XML file
+        """
+        yield StartElement(tag="Message")
+        yield StartElement(tag="Header")
+        yield from self.generate_text_element(tag="Version", cell=1)
+        yield EndElement(tag="Header")
+        yield StartElement(tag="LALevelVacancies")
+        yield from self.generate_text_element(tag="NumberOfVacancies", cell=100)
+        yield EndElement(tag="LALevelVacancies")
+        yield StartElement(tag="CSWWWorker")
+        yield from self.generate_text_element(tag="ID", cell=100)
+        yield from self.generate_text_element(tag="SWENo", cell="AB123456789")
+        yield from self.generate_text_element(tag="Agency", cell=0)
+        yield EndElement(tag="CSWWWorker")
+        yield EndElement(tag="Message")
+
+    def test_text_collector(self):
+        # test that the text_collector returns a dictionary of events and their text values from the stream
+        test_stream = self.generate_test_csww_file()
+        test_record = text_collector(test_stream)
+        self.assertEqual(len(test_record), 5)
+        self.assertEqual(
+            test_record,
+            {
+                "Version": "1",
+                "NumberOfVacancies": "100",
+                "ID": "100",
+                "SWENo": "AB123456789",
+                "Agency": "0",
+            },
+        )
+
+    def test_export_table(self):
+        test_stream = self.generate_test_csww_file()
+        test_events = list(message_collector(test_stream))
+        dataset_holder, stream = export_table(test_events)
+
+        self.assertEqual(len(list(stream)), 3)
+
+        data = dataset_holder.value
+        self.assertEqual(len(data), 3)
+        self.assertEqual(data["Header"], [{"Version": "1"}])
+        self.assertEqual(data["LA_Level"], [{"NumberOfVacancies": "100"}])
+        self.assertEqual(
+            data["Worker"], [{"ID": "100", "SWENo": "AB123456789", "Agency": "0"}]
+        )
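
The behaviour pinned down by `test_reduce_dict`, in one line: single-value lists collapse to scalars, while genuinely repeated elements keep their list. For instance:

```python
from liiatools.common.stream_record import _reduce_dict

# Mirrors the new test: "ID" collapses, "ReasonAbsence" stays a list.
assert _reduce_dict({"ID": ["100"], "ReasonAbsence": ["MAT", "OTH"]}) == {
    "ID": "100",
    "ReasonAbsence": ["MAT", "OTH"],
}
```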
diff --git a/tests/social_work_workforce/test_stream_filters.py b/tests/social_work_workforce/test_stream_filters.py
index 8b631918..66cd0f3d 100644
--- a/tests/social_work_workforce/test_stream_filters.py
+++ b/tests/social_work_workforce/test_stream_filters.py
@@ -1,217 +1,13 @@
 from collections import namedtuple
-import xml.etree.ElementTree as ET
-from io import BytesIO
-from typing import Iterable
 
-from sfdata_stream_parser.events import StartElement, EndElement, TextNode, ParseEvent
-from liiatools.csww_pipeline.stream_parse import dom_parse
+from sfdata_stream_parser.events import TextNode
 
 from liiatools.common.spec.__data_schema import (
     Column,
     Numeric,
     Category,
 )
-from liiatools.csww_pipeline.spec import (
-    load_schema,
-    load_schema_path,
-)
-from liiatools.csww_pipeline.spec.samples import CSWW_2022
-from liiatools.csww_pipeline.stream_filters import (
-    strip_text,
-    add_context,
-    add_schema,
-    validate_elements,
-    _create_category_spec,
-    _create_numeric_spec,
-    _create_regex_spec,
-    add_column_spec,
-)
-
-
-def test_strip_text():
-    stream = [
-        TextNode(text=None),
-        TextNode(cell=None, text=None),
-        TextNode(cell="string", text=None),
-        TextNode(cell=" string_with_whitespace ", text=None),
-    ]
-
-    stripped_stream = list(strip_text(stream))
-    assert stream[0] == stripped_stream[0]
-    assert stream[1] == stripped_stream[1]
-    assert stripped_stream[2].cell == "string"
-    assert stripped_stream[3].cell == "string_with_whitespace"
-
-
-def test_add_context():
-    stream = [
-        StartElement(tag="Message"),
-        StartElement(tag="Header"),
-        TextNode(cell="string", text=None),
-        EndElement(tag="Header"),
-        EndElement(tag="Message"),
-    ]
-
-    context_stream = list(add_context(stream))
-    assert context_stream[0].context == ("Message",)
-    assert context_stream[1].context == ("Message", "Header")
-    assert context_stream[2].context == ("Message", "Header")
-    assert context_stream[3].context == ("Message", "Header")
-    assert context_stream[4].context == ("Message",)
-
-
-def test_add_schema():
-    schema = load_schema(year=2022)
-    stream = [
-        StartElement(tag="Message", context=("Message",)),
-        StartElement(tag="Header", context=("Message", "Header")),
-        TextNode(cell="string", text=None, context=("Message", "Header")),
-        EndElement(tag="Header", context=("Message", "Header")),
-        EndElement(tag="Message", context=("Message",)),
-    ]
-
-    schema_stream = list(add_schema(stream, schema=schema))
-
-    assert schema_stream[0].schema.name == "Message"
-    assert schema_stream[0].schema.occurs == (1, 1)
-    assert schema_stream[1].schema.name == "Header"
-    assert schema_stream[1].schema.occurs == (0, 1)
-    assert schema_stream[2].schema.name == "Header"
-    assert schema_stream[2].schema.occurs == (0, 1)
-    assert schema_stream[3].schema.name == "Header"
-    assert schema_stream[3].schema.occurs == (0, 1)
-    assert schema_stream[4].schema.name == "Message"
-    assert schema_stream[4].schema.occurs == (1, 1)
-
-
-def _xml_to_stream(root) -> Iterable[ParseEvent]:
-    schema = load_schema(2022)
-
-    input = BytesIO(ET.tostring(root, encoding="utf-8"))
-    stream = dom_parse(input, filename="test.xml")
-    stream = strip_text(stream)
-    stream = add_context(stream)
-    stream = add_schema(stream, schema=schema)
-    stream = validate_elements(stream)
-    return list(stream)
-
-
-def test_validate_all_valid():
-    with CSWW_2022.open("rb") as f:
-        root = ET.parse(f).getroot()
-
-    stream = _xml_to_stream(root)
-
-    for event in stream:
-        assert not hasattr(event, "errors")
-
-
-def test_validate_missing_required_field():
-    with CSWW_2022.open("rb") as f:
-        root = ET.parse(f).getroot()
-
-    parent = root.find(".//CSWWWorker")
-    el = parent.find("AgencyWorker")
-    parent.remove(el)
-
-    stream = _xml_to_stream(root)
-
-    errors = []
-    for event in stream:
-        if hasattr(event, "errors"):
-            errors.append(event.errors)
-
-    assert list(errors[0])[0] == {
-        "type": "ValidationError",
-        "message": "Invalid node",
-        "exception": "Missing required field: 'AgencyWorker' which occurs in the node starting on line: 20",
-    }
-
-
-def test_validate_reordered_required_field():
-    with CSWW_2022.open("rb") as f:
-        root = ET.parse(f).getroot()
-
-    el_parent = root.find(".//AgencyWorker/..")
-    el_child_id = el_parent.find("AgencyWorker")
-    el_parent.remove(el_child_id)
-    el_parent.append(el_child_id)
-
-    stream = _xml_to_stream(root)
-
-    errors = []
-    for event in stream:
-        if hasattr(event, "errors"):
-            errors.append(event.errors)
-
-    assert list(errors[0])[0] == {
-        "type": "ValidationError",
-        "message": "Invalid node",
-        "exception": "Missing required field: 'AgencyWorker' which occurs in the node starting on line: 20",
-    }
-
-
-def test_validate_unexpected_node():
-    with CSWW_2022.open("rb") as f:
-        root = ET.parse(f).getroot()
-
-    parent = root.find(".//CSWWWorker")
-    ET.SubElement(parent, "Unknown_Node")
-
-    stream = _xml_to_stream(root)
-
-    errors = []
-    for event in stream:
-        if hasattr(event, "errors"):
-            errors.append(event.errors)
-
-    assert list(errors[0])[0] == {
-        "type": "ValidationError",
-        "message": "Invalid node",
-        "exception": "Unexpected node 'Unknown_Node'",
-    }
-
-
-def test_create_category_spec():
-    schema_path = load_schema_path(2022)
-    field = "agencyworkertype"
-    category_spec = _create_category_spec(field, schema_path)
-
-    assert category_spec == [
-        Category(
-            code="0",
-            name="Not an Agency Worker",
-            cell_regex=None,
-            model_config={"extra": "forbid"},
-        ),
-        Category(
-            code="1",
-            name="Agency Worker",
-            cell_regex=None,
-            model_config={"extra": "forbid"},
-        ),
-    ]
-
-
-def test_create_numeric_spec():
-    schema_path = load_schema_path(2022)
-    field = "twodecimalplaces"
-    numeric_spec = _create_numeric_spec(field, schema_path)
-
-    assert numeric_spec == Numeric(
-        type="float",
-        min_value=None,
-        max_value=None,
-        decimal_places=2,
-        model_config={"extra": "forbid"},
-    )
-
-
-def test_create_regex_spec():
-    schema_path = load_schema_path(2022)
-    field = "swetype"
-    regex_spec = _create_regex_spec(field, schema_path)
-
-    assert regex_spec == r"[A-Za-z]{2}\d{10}"
+from liiatools.csww_pipeline.spec import load_schema_path
+from liiatools.csww_pipeline.stream_filters import add_column_spec
 
 
 def test_add_column_spec():
diff --git a/tests/social_work_workforce/test_stream_pipeline.py b/tests/social_work_workforce/test_stream_pipeline.py
index fddef523..5e47e35c 100644
--- a/tests/social_work_workforce/test_stream_pipeline.py
+++ b/tests/social_work_workforce/test_stream_pipeline.py
@@ -60,4 +60,4 @@ def test_task_cleanfile_error():
     assert errors[0]["filename"] == "social_work_workforce_2022_error.xml"
     assert errors[0]["header"] == "DateTime"
 
-    os.remove(SAMPLES_DIR / 'social_work_workforce_2022_error.xml')
+    os.remove(SAMPLES_DIR / "social_work_workforce_2022_error.xml")
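
The XML validation tests deleted here are re-homed under tests/common/test_stream_filters.py above. They rely on `validate_elements` attaching an `errors` iterable to offending events rather than raising. A hypothetical helper (not in the codebase) showing that harvesting pattern:

```python
def iter_validation_errors(stream):
    """Yield every error dict attached to events by validate_elements."""
    for event in stream:
        # Events without problems simply have no "errors" attribute.
        yield from getattr(event, "errors", ())
```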
diff --git a/tests/social_work_workforce/test_stream_record.py b/tests/social_work_workforce/test_stream_record.py
index 9c06ee09..71479fbf 100644
--- a/tests/social_work_workforce/test_stream_record.py
+++ b/tests/social_work_workforce/test_stream_record.py
@@ -6,28 +6,10 @@
     CSWWEvent,
     LALevelEvent,
     HeaderEvent,
-    _reduce_dict,
-    text_collector,
     message_collector,
-    export_table,
 )
 
 
-def test_reduce_dict():
-    sample_dict = {
-        "ID": ["100"],
-        "SWENo": ["AB123456789"],
-        "Agency": ["0"],
-        "ReasonAbsence": ["MAT", "OTH"],
-    }
-    assert _reduce_dict(sample_dict) == {
-        "ID": "100",
-        "SWENo": "AB123456789",
-        "Agency": "0",
-        "ReasonAbsence": ["MAT", "OTH"],
-    }
-
-
 class TestRecord(unittest.TestCase):
     def generate_text_element(self, tag: str, cell):
         """
@@ -61,22 +43,6 @@ def generate_test_csww_file(self):
         yield EndElement(tag="CSWWWorker")
         yield EndElement(tag="Message")
 
-    def test_text_collector(self):
-        # test that the text_collector returns a dictionary of events and their text values from the stream
-        test_stream = self.generate_test_csww_file()
-        test_record = text_collector(test_stream)
-        self.assertEqual(len(test_record), 5)
-        self.assertEqual(
-            test_record,
-            {
-                "Version": "1",
-                "NumberOfVacancies": "100",
-                "ID": "100",
-                "SWENo": "AB123456789",
-                "Agency": "0",
-            },
-        )
-
     def test_message_collector(self):
         # test that the message_collector yields events of the correct type from the stream
         test_stream = self.generate_test_csww_file()
@@ -90,18 +56,3 @@ def test_message_collector(self):
         self.assertEqual(
             test_events[2].record, {"ID": "100", "SWENo": "AB123456789", "Agency": "0"}
         )
-
-    def test_export_table(self):
-        test_stream = self.generate_test_csww_file()
-        test_events = list(message_collector(test_stream))
-        dataset_holder, stream = export_table(test_events)
-
-        self.assertEqual(len(list(stream)), 3)
-
-        data = dataset_holder.value
-        self.assertEqual(len(data), 3)
-        self.assertEqual(data["Header"], [{"Version": "1"}])
-        self.assertEqual(data["LA_Level"], [{"NumberOfVacancies": "100"}])
-        self.assertEqual(
-            data["Worker"], [{"ID": "100", "SWENo": "AB123456789", "Agency": "0"}]
-        )
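
Putting the refactor together: each pipeline keeps its own `message_collector`, while `text_collector` and `export_table` are now shared via `liiatools.common.stream_record`. A sketch of the resulting composition, grounded in the stream_pipeline.py hunk and test_export_table above (the `events` argument stands in for an already-parsed XML event stream):

```python
from sfdata_stream_parser.filters import generic

from liiatools.common.stream_record import export_table
from liiatools.csww_pipeline.stream_record import message_collector


def collect_csww_tables(events):
    stream = message_collector(events)
    dataset_holder, stream = export_table(stream)
    generic.consume(stream)  # drain the stream so the holder is populated
    # e.g. {"Header": [...], "LA_Level": [...], "Worker": [...]}
    return dataset_holder.value
```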