Skip to content

Commit 080b954

Browse files
Add a mixin to make the task skippable (#7)
1 parent 0901ed5 commit 080b954

File tree

2 files changed

+217
-0
lines changed

2 files changed

+217
-0
lines changed

data_validation_framework/task.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import time
55
import traceback
66
import warnings
7+
from functools import partial
78
from pathlib import Path
89

910
import luigi
@@ -21,13 +22,15 @@
2122
from numpy import VisibleDeprecationWarning
2223

2324
from data_validation_framework.report import make_report
25+
from data_validation_framework.result import ValidationResult
2426
from data_validation_framework.result import ValidationResultSet
2527
from data_validation_framework.target import ReportTarget
2628
from data_validation_framework.target import TaggedOutputLocalTarget
2729
from data_validation_framework.util import apply_to_df
2830

2931
L = logging.getLogger(__name__)
3032
INDEX_LABEL = "__index_label__"
33+
SKIP_COMMENT = "Skipped by user."
3134

3235

3336
class ValidationError(Exception):
@@ -660,3 +663,63 @@ def validation_function(*args, **kwargs):
660663
This method should usually do nothing for :class:`ValidationWorkflow` as this class is only
661664
supposed to gather validation steps.
662665
"""
666+
667+
668+
def _skippable_element_validation_function(validation_function, skip, *args, **kwargs):
669+
"""Skipping wrapper for an element validation function."""
670+
if skip:
671+
return ValidationResult(is_valid=True, comment=SKIP_COMMENT)
672+
return validation_function(*args, **kwargs)
673+
674+
675+
def _skippable_set_validation_function(validation_function, skip, *args, **kwargs):
676+
"""Skipping wrapper for a set validation function."""
677+
df = kwargs.get("df", args[0])
678+
if skip:
679+
df.loc[df["is_valid"], "comment"] = SKIP_COMMENT
680+
else:
681+
validation_function(*args, **kwargs)
682+
683+
684+
def SkippableMixin(default_value=False):
685+
"""Create a mixin class to add a ``skip`` parameter.
686+
687+
This mixin must be applied to a :class:`data_validation_framework.ElementValidationTask`.
688+
It will create a ``skip`` parameter and wrap the validation function to just skip it if the
689+
``skip`` argument is set to ``True``. If skipped, it will keep the ``is_valid`` values as is and
690+
add a specific comment to inform the user.
691+
692+
Args:
693+
default_value (bool): The default value for the ``skip`` argument.
694+
"""
695+
696+
class Mixin:
697+
"""A mixin to add a ``skip`` parameter to a :class:`luigi.task`."""
698+
699+
skip = BoolParameter(default=default_value, description=":bool: Skip the task")
700+
701+
def __init__(self, *args, **kwargs):
702+
703+
super().__init__(*args, **kwargs)
704+
705+
if isinstance(self, ElementValidationTask):
706+
new_validation_function = partial(
707+
_skippable_element_validation_function,
708+
self.validation_function,
709+
self.skip,
710+
)
711+
elif isinstance(self, SetValidationTask) and not isinstance(self, ValidationWorkflow):
712+
new_validation_function = partial(
713+
_skippable_set_validation_function,
714+
self.validation_function,
715+
self.skip,
716+
)
717+
else:
718+
raise TypeError(
719+
"The SkippableMixin can only be associated with childs of ElementValidationTask"
720+
" or SetValidationTask"
721+
)
722+
self._skippable_validation_function = self.validation_function
723+
self.validation_function = new_validation_function
724+
725+
return Mixin

tests/test_task.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,3 +1568,157 @@ def test_nested_workflows(
15681568
data_dir / "test_report_before_run" / "report_rst2pdf_nested.pdf",
15691569
threshold=25,
15701570
)
1571+
1572+
1573+
class TestSkippableMixin:
1574+
"""Test the data_validation_framework.task.SkippableMixin class."""
1575+
1576+
def test_fail_parent_type(self):
1577+
err_msg = (
1578+
"The SkippableMixin can only be associated with childs of ElementValidationTask"
1579+
" or SetValidationTask"
1580+
)
1581+
1582+
class TestTask1(task.SkippableMixin(), luigi.Task):
1583+
pass
1584+
1585+
with pytest.raises(
1586+
TypeError,
1587+
match=err_msg,
1588+
):
1589+
TestTask1()
1590+
1591+
class TestTask2(task.SkippableMixin(), task.ValidationWorkflow):
1592+
pass
1593+
1594+
with pytest.raises(
1595+
TypeError,
1596+
match=err_msg,
1597+
):
1598+
TestTask2()
1599+
1600+
def test_skip_element_task(self, dataset_df_path, tmpdir):
1601+
class TestSkippableTask(task.SkippableMixin(), task.ElementValidationTask):
1602+
@staticmethod
1603+
# pylint: disable=arguments-differ
1604+
def validation_function(row, output_path, *args, **kwargs):
1605+
if row["a"] <= 1:
1606+
return result.ValidationResult(is_valid=True)
1607+
if row["a"] <= 2:
1608+
return result.ValidationResult(is_valid=False, comment="bad value")
1609+
raise ValueError(f"Incorrect value {row['a']}")
1610+
1611+
# Test with no given skip value (should be False by default)
1612+
assert luigi.build(
1613+
[
1614+
TestSkippableTask(
1615+
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_default")
1616+
)
1617+
],
1618+
local_scheduler=True,
1619+
)
1620+
1621+
report_data = pd.read_csv(tmpdir / "out_default" / "TestSkippableTask" / "report.csv")
1622+
assert (report_data["is_valid"] == [True, False]).all()
1623+
assert (report_data["comment"].isnull() == [True, False]).all()
1624+
assert report_data.loc[1, "comment"] == "bad value"
1625+
assert report_data["exception"].isnull().all()
1626+
1627+
# Test with no given skip value (should be False by default)
1628+
assert luigi.build(
1629+
[
1630+
TestSkippableTask(
1631+
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_no_skip"), skip=False
1632+
)
1633+
],
1634+
local_scheduler=True,
1635+
)
1636+
1637+
report_data = pd.read_csv(tmpdir / "out_no_skip" / "TestSkippableTask" / "report.csv")
1638+
assert (report_data["is_valid"] == [True, False]).all()
1639+
assert (report_data["comment"].isnull() == [True, False]).all()
1640+
assert report_data.loc[1, "comment"] == "bad value"
1641+
assert report_data["exception"].isnull().all()
1642+
1643+
# Test with no given skip value (should be False by default)
1644+
assert luigi.build(
1645+
[
1646+
TestSkippableTask(
1647+
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_skip"), skip=True
1648+
)
1649+
],
1650+
local_scheduler=True,
1651+
)
1652+
1653+
report_data = pd.read_csv(tmpdir / "out_skip" / "TestSkippableTask" / "report.csv")
1654+
assert (
1655+
report_data["is_valid"] == True # noqa ; pylint: disable=singleton-comparison
1656+
).all()
1657+
assert (report_data["comment"] == "Skipped by user.").all()
1658+
assert report_data["exception"].isnull().all()
1659+
1660+
def test_skip_set_task(self, dataset_df_path, tmpdir):
1661+
class TestSkippableTask(task.SkippableMixin(), task.SetValidationTask):
1662+
@staticmethod
1663+
def validation_function(df, output_path, *args, **kwargs):
1664+
# pylint: disable=no-member
1665+
df["a"] *= 10
1666+
df.loc[1, "is_valid"] = False
1667+
df.loc[1, "ret_code"] = 1
1668+
df[["a", "b"]].to_csv(output_path / "test.csv")
1669+
1670+
# Test with no given skip value (should be False by default)
1671+
assert luigi.build(
1672+
[
1673+
TestSkippableTask(
1674+
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_default")
1675+
)
1676+
],
1677+
local_scheduler=True,
1678+
)
1679+
1680+
res = pd.read_csv(tmpdir / "out_default" / "TestSkippableTask" / "data" / "test.csv")
1681+
expected = pd.read_csv(tmpdir / "dataset.csv")
1682+
expected["a"] *= 10
1683+
assert res.equals(expected)
1684+
report_data = pd.read_csv(tmpdir / "out_default" / "TestSkippableTask" / "report.csv")
1685+
assert (report_data["is_valid"] == [True, False]).all()
1686+
assert report_data["comment"].isnull().all()
1687+
assert report_data["exception"].isnull().all()
1688+
1689+
# Test with skip = False
1690+
assert luigi.build(
1691+
[
1692+
TestSkippableTask(
1693+
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_no_skip"), skip=False
1694+
)
1695+
],
1696+
local_scheduler=True,
1697+
)
1698+
1699+
res = pd.read_csv(tmpdir / "out_no_skip" / "TestSkippableTask" / "data" / "test.csv")
1700+
expected = pd.read_csv(tmpdir / "dataset.csv")
1701+
expected["a"] *= 10
1702+
assert res.equals(expected)
1703+
report_data = pd.read_csv(tmpdir / "out_no_skip" / "TestSkippableTask" / "report.csv")
1704+
assert (report_data["is_valid"] == [True, False]).all()
1705+
assert report_data["comment"].isnull().all()
1706+
assert report_data["exception"].isnull().all()
1707+
1708+
# Test with skip = True
1709+
assert luigi.build(
1710+
[
1711+
TestSkippableTask(
1712+
dataset_df=dataset_df_path, result_path=str(tmpdir / "out_skip"), skip=True
1713+
)
1714+
],
1715+
local_scheduler=True,
1716+
)
1717+
1718+
assert not (tmpdir / "out_skip" / "TestSkippableTask" / "data" / "test.csv").exists()
1719+
report_data = pd.read_csv(tmpdir / "out_skip" / "TestSkippableTask" / "report.csv")
1720+
assert (
1721+
report_data["is_valid"] == True # noqa ; pylint: disable=singleton-comparison
1722+
).all()
1723+
assert (report_data["comment"] == "Skipped by user.").all()
1724+
assert report_data["exception"].isnull().all()

0 commit comments

Comments
 (0)