diff --git a/.coveragerc b/.coveragerc
index 3ad2277..06caa1b 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -3,5 +3,5 @@
 omit = tests/*
 
 [report]
-exclude_also = 
+exclude_also =
     if __name__ == .__main__.:
diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index b35c772..c542ae6 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -31,5 +31,3 @@ python:
   install:
     - method: pip
       path: .[docs]
-
-
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 098ee18..5d2e2cb 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -4,6 +4,12 @@
     ],
     "python.testing.unittestEnabled": false,
     "python.testing.pytestEnabled": true,
-    "python.linting.flake8Enabled": true,
-    "python.linting.enabled": true
+    "flake8.args": [
+        "--max-line-length=88",
+        "--extend-ignore=E203",
+    ],
+    "[python]": {
+        "editor.defaultFormatter": "ms-python.black-formatter"
+    },
+    "python.testing.cwd": "${workspaceFolder}"
 }
diff --git a/adtl/__init__.py b/adtl/__init__.py
index ead6f4c..00afaf1 100644
--- a/adtl/__init__.py
+++ b/adtl/__init__.py
@@ -86,12 +86,14 @@ def get_value_unhashed(row: StrDict, rule: Rule, ctx: Context = None) -> Any:
             params.append(row[rule["apply"]["params"][i][1:]])
         elif isinstance(rule["apply"]["params"][i], list):
             param = [
-                row[rule["apply"]["params"][i][j][1:]]
-                if (
-                    isinstance(rule["apply"]["params"][i][j], str)
-                    and rule["apply"]["params"][i][j].startswith("$")
+                (
+                    row[rule["apply"]["params"][i][j][1:]]
+                    if (
+                        isinstance(rule["apply"]["params"][i][j], str)
+                        and rule["apply"]["params"][i][j].startswith("$")
+                    )
+                    else rule["apply"]["params"][i][j]
                 )
-                else rule["apply"]["params"][i][j]
                 for j in range(len(rule["apply"]["params"][i]))
             ]
             params.append(param)
@@ -590,9 +592,11 @@ def ctx(self, attribute: str):
             "defaultDateFormat": self.header.get(
                 "defaultDateFormat", DEFAULT_DATE_FORMAT
             ),
-            "skip_pattern": re.compile(self.header.get("skipFieldPattern"))
-            if self.header.get("skipFieldPattern")
-            else False,
+            "skip_pattern": (
+                re.compile(self.header.get("skipFieldPattern"))
+                if self.header.get("skipFieldPattern")
+                else False
+            ),
         }
 
     def validate_spec(self):
@@ -759,12 +763,14 @@ def parse(self, file: str, encoding: str = "utf-8", skip_validation=False):
         with open(file, encoding=encoding) as fp:
             reader = csv.DictReader(fp)
             return self.parse_rows(
-                tqdm(
-                    reader,
-                    desc=f"[{self.name}] parsing {Path(file).name}",
-                )
-                if not self.quiet
-                else reader,
+                (
+                    tqdm(
+                        reader,
+                        desc=f"[{self.name}] parsing {Path(file).name}",
+                    )
+                    if not self.quiet
+                    else reader
+                ),
                 skip_validation=skip_validation,
             )
 
@@ -864,6 +870,55 @@ def writerows(fp, table):
             buf = io.StringIO()
             return writerows(buf, table).getvalue()
 
+    def write_parquet(
+        self,
+        table: str,
+        output: Optional[str] = None,
+    ) -> Optional[str]:
+        """Writes to output as parquet a particular table
+
+        Args:
+            table: Table that should be written to parquet
+            output: (optional) Output file name. If not specified, defaults to parser
+                name + table name with a parquet suffix.
+        """
+
+        try:
+            import polars as pl
+        except ImportError:
+            raise ImportError(
+                "Parquet output requires the polars library. "
" + "Install with 'pip install polars'" + ) + + # Read the table data + data = list(self.read_table(table)) + + # Convert data to Polars DataFrame + df = pl.DataFrame(data, infer_schema_length=len(data)) + + if table in self.validators: + valid_cols = [c for c in ["adtl_valid", "adtl_error"] if c in df.columns] + df_validated = df.select( + valid_cols + + [ + *[ + col + for col in df.columns + if (col != "adtl_valid" and col != "adtl_error") + ], # All other columns, in their original order + ] + ) + else: + df_validated = df + + if output: + df_validated.write_parquet(output) + else: + buf = io.BytesIO() + df_validated.write_parquet(buf) + return buf.getvalue() + def show_report(self): "Shows report with validation errors" if self.report_available: @@ -883,15 +938,20 @@ def show_report(self): print(f"* {count}: {message}") print() - def save(self, output: Optional[str] = None): + def save(self, output: Optional[str] = None, parquet=False): """Saves all tables to CSV Args: output: (optional) Filename prefix that is used for all tables """ - for table in self.tables: - self.write_csv(table, f"{output}-{table}.csv") + if parquet: + for table in self.tables: + self.write_parquet(table, f"{output}-{table}.parquet") + + else: + for table in self.tables: + self.write_csv(table, f"{output}-{table}.csv") def main(argv=None): @@ -910,6 +970,9 @@ def main(argv=None): cmd.add_argument( "--encoding", help="encoding input file is in", default="utf-8-sig" ) + cmd.add_argument( + "--parquet", help="output file is in parquet format", action="store_true" + ) cmd.add_argument( "-q", "--quiet", @@ -929,7 +992,7 @@ def main(argv=None): # run adtl adtl_output = spec.parse(args.file, encoding=args.encoding) - adtl_output.save(args.output or spec.name) + adtl_output.save(args.output or spec.name, args.parquet) if args.save_report: adtl_output.report.update( dict( diff --git a/docs/adtl.rst b/docs/adtl.rst index 8dd2e4b..d6247f8 100644 --- a/docs/adtl.rst +++ b/docs/adtl.rst @@ -3,4 +3,4 @@ Module reference ===================== .. autoclass:: adtl.Parser - :members: \ No newline at end of file + :members: diff --git a/docs/getting_started/installation.md b/docs/getting_started/installation.md index 0ab0748..dae4cc9 100644 --- a/docs/getting_started/installation.md +++ b/docs/getting_started/installation.md @@ -26,4 +26,3 @@ If you are writing code which depends on adtl (instead of using the command-line program), then it is best to add a dependency on `git+https://github.com/globaldothealth/adtl` to your Python build tool of choice. - diff --git a/docs/index.md b/docs/index.md index da5b831..0bb9afc 100644 --- a/docs/index.md +++ b/docs/index.md @@ -47,4 +47,4 @@ maxdepth: 1 adtl transformations -``` \ No newline at end of file +``` diff --git a/docs/transformations.rst b/docs/transformations.rst index d8c9032..be0c85f 100644 --- a/docs/transformations.rst +++ b/docs/transformations.rst @@ -3,4 +3,4 @@ Transformations ===================== .. 
 .. automodule:: adtl.transformations
-   :members:
\ No newline at end of file
+   :members:
diff --git a/pyproject.toml b/pyproject.toml
index f3d78d2..640be5c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,13 +34,17 @@ test = [
     "pytest-cov",
     "syrupy==4.*",
     "responses",
-    "pytest-unordered"
+    "pytest-unordered",
+    "adtl[parquet]"
 ]
 docs = [
     "sphinx>=7.2.2",
     "myst-parser==2.0.*",
     "furo"
 ]
+parquet = [
+    "polars"
+]
 
 [project.urls]
 Home = "https://github.com/globaldothealth/adtl"
diff --git a/schemas/dev.schema.json b/schemas/dev.schema.json
index c794ca8..3624c12 100644
--- a/schemas/dev.schema.json
+++ b/schemas/dev.schema.json
@@ -226,4 +226,4 @@
         }
       }
     }
-}
\ No newline at end of file
+}
diff --git a/tests/parsers/epoch-web-schema.json b/tests/parsers/epoch-web-schema.json
index 8f58f18..d8c0e66 100644
--- a/tests/parsers/epoch-web-schema.json
+++ b/tests/parsers/epoch-web-schema.json
@@ -24,4 +24,4 @@
       "field": "Text"
     }
   }
-}
\ No newline at end of file
+}
diff --git a/tests/parsers/epoch.yml b/tests/parsers/epoch.yml
index f548dce..4f0672f 100644
--- a/tests/parsers/epoch.yml
+++ b/tests/parsers/epoch.yml
@@ -1,11 +1,11 @@
 adtl:
   name: default-date-format
   description: Tests default date format
-  defaultDateFormat: %d/%m/%Y
+  defaultDateFormat: "%d/%m/%Y"
   tables:
     - table: { kind: oneToOne, schema: ../schemas/epoch-data.schema.json }
 table:
   id: { field: Entry_ID }
   epoch: { field: Epoch }
   some_date: { field: SomeDate }
-  text: { field: Text }
\ No newline at end of file
+  text: { field: Text }
diff --git a/tests/parsers/skip_field.json b/tests/parsers/skip_field.json
index 97079cc..3998743 100644
--- a/tests/parsers/skip_field.json
+++ b/tests/parsers/skip_field.json
@@ -31,4 +31,4 @@
       "can_skip": true
     }
   }
-}
\ No newline at end of file
+}
diff --git a/tests/schemas/epoch-data.schema.json b/tests/schemas/epoch-data.schema.json
index abdc5f7..8433465 100644
--- a/tests/schemas/epoch-data.schema.json
+++ b/tests/schemas/epoch-data.schema.json
@@ -29,4 +29,4 @@
       "description": "Follow-up cough field"
     }
   }
-}
\ No newline at end of file
+}
diff --git a/tests/schemas/epoch-oneOf.schema.json b/tests/schemas/epoch-oneOf.schema.json
index 2c67d56..4a29d22 100644
--- a/tests/schemas/epoch-oneOf.schema.json
+++ b/tests/schemas/epoch-oneOf.schema.json
@@ -58,4 +58,4 @@
       "description": "sex"
     }
   }
-}
\ No newline at end of file
+}
diff --git a/tests/schemas/observation_defaultif.schema.json b/tests/schemas/observation_defaultif.schema.json
index 525f42c..ba16658 100644
--- a/tests/schemas/observation_defaultif.schema.json
+++ b/tests/schemas/observation_defaultif.schema.json
@@ -72,4 +72,4 @@
       ]
     }
   }
-}
\ No newline at end of file
+}
diff --git a/tests/sources/oneToMany.csv b/tests/sources/oneToMany.csv
index d74ea7b..fff0d1d 100644
--- a/tests/sources/oneToMany.csv
+++ b/tests/sources/oneToMany.csv
@@ -1,2 +1,2 @@
 dt,headache_cmyn,cough_cmyn,dyspnea_cmyn
-2022-02-05,1,1,0
\ No newline at end of file
+2022-02-05,1,1,0
diff --git a/tests/test_parser.py b/tests/test_parser.py
index 279e079..b946939 100644
--- a/tests/test_parser.py
+++ b/tests/test_parser.py
@@ -1213,6 +1213,12 @@ def test_main(snapshot):
     Path("output-table.csv").unlink()
 
 
+def test_main_parquet():
+    parser.main(ARGV + ["--parquet"])
+    assert Path("output-table.parquet").exists()
+    Path("output-table.parquet").unlink()
+
+
 @responses.activate
 def test_main_web_schema(snapshot):
     # test with schema on the web