Skip to content

Commit c10cadb

Browse files
New mechanism for inputs (#10)
New mechanism for inputs BREAKING CHANGE: The old mechanism is deprecated.
1 parent 750f9fb commit c10cadb

File tree

3 files changed

+318
-91
lines changed

3 files changed

+318
-91
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class ValidationTask2(dvf.task.SetValidationTask):
7777
ValidationTask2."""
7878

7979
def inputs(self):
80-
return {ValidationTask1(): {"col_name": "new_col_name_in_current_task"}}
80+
return {ValidationTask1: {"col_name": "new_col_name_in_current_task"}}
8181

8282
def kwargs(self):
8383
return {"param_value": self.a_parameter}
@@ -90,8 +90,8 @@ class ValidationWorkflow(dvf.task.ValidationWorkflow):
9090

9191
def inputs(self):
9292
return {
93-
ValidationTask1(): {},
94-
ValidationTask2(): {},
93+
ValidationTask1: {},
94+
ValidationTask2: {},
9595
}
9696
```
9797

data_validation_framework/task.py

Lines changed: 113 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,41 @@ def num_tag(path, num):
8686
TaggedOutputLocalTarget.set_default_prefix(self.result_path)
8787

8888

89+
class InputParameters:
    """A helper to build task inputs.

    This class should only be used to build the input dictionaries:

    .. code-block:: python

        class Task(ChildBaseValidationTask):
            data_dir = luigi.Parameter(default='my_data')
            validation_function = staticmethod(my_validation_function)
            output_columns = {'my_new_col': None}

            def inputs(self):
                return {
                    PreviousTask: InputParameters(
                        {"input_data": "input_data"},
                        kwarg_1="value_1",
                        kwarg_2="value_2",
                    )
                }

    Args:
        col_mapping (dict): The column mapping.

    Keyword Args:
        kwargs: All the keyword arguments passed to the constructor.

    .. warning:: The keyword arguments will always override the arguments from the config file.
    """

    def __init__(self, col_mapping, **kwargs):
        # Mapping from column names in the required task's report to the
        # column names used by the current task.
        self.col_mapping = col_mapping
        # Keyword arguments forwarded to the required task's constructor.
        self.kwargs = kwargs
122+
123+
89124
class BaseValidationTask(LogTargetMixin, RerunMixin, TagResultOutputMixin, luigi.Task):
90125
"""Base luigi task used for validation steps.
91126
@@ -103,7 +138,7 @@ class Task(ChildBaseValidationTask):
103138
output_columns = {'my_new_col': None}
104139
105140
def inputs(self):
106-
return {'PreviousTask(): {"input_data", "input_data"}
141+
return {PreviousTask: {"input_data": "input_data"}}
107142
108143
.. note::
109144
@@ -195,17 +230,41 @@ def success_summary(self):
195230
def inputs(self):
    """Information about required input data.

    This method can be overridden and should return a dict of one of the following forms:

    .. code-block:: python

        {<task_name>: {"<input_column_name>": "<current_column_name>"}}

    .. code-block:: python

        {
            <task_name>: (
                {
                    "<kwarg_1>": "<value_1>",
                    "<kwarg_2>": "<value_2>",
                },
                {"<input_column_name>": "<current_column_name>"},
            )
        }

    .. code-block:: python

        {
            <task_name>: InputParameters(
                {"<input_column_name>": "<current_column_name>"},
                kwarg_1="<value_1>",
                kwarg_2="<value_2>",
            )
        }

    where:
        - ``<task_name>`` is the name of the required task,
        - ``<input_column_name>`` is the name of the column we need from the report of
          task_name,
        - ``<current_column_name>`` is the name of the same column in the current task.
        - ``<kwarg_*>`` is the name of a keyword argument passed to the constructor of the
          required task.

    .. note::
        In the tuple form, the keyword-argument dict is the FIRST element and the
        column mapping the SECOND, matching how ``processed_inputs`` unpacks the
        value (``specific_kwargs, col_mapping = v``).
    """
    # pylint: disable=no-self-use
    # Base implementation has no required inputs; subclasses override this.
    return None
@@ -243,22 +302,52 @@ def pre_process(self, df, args, kwargs):
243302
def post_process(self, df, args, kwargs):
244303
"""Method executed after applying the external function."""
245304

305+
def processed_inputs(self):
    """Process the inputs to automatically propagate the values from the workflow.

    Returns:
        dict: A mapping of instantiated required tasks to their column mappings,
        or ``None`` when :meth:`inputs` provides nothing.

    Raises:
        ValueError: If an input value is neither a dict, an ``InputParameters``
            instance, nor a 2-element (kwargs, col_mapping) tuple.
    """
    inputs = self.inputs()  # pylint: disable=assignment-from-none

    if not inputs:
        return None

    # Values propagated from the current task to every required task; they are
    # loop-invariant, so build them once instead of per input task.
    propagated_kwargs = {
        "dataset_df": self.dataset_df,
        "input_index_col": self.input_index_col,
        "result_path": self.result_path,
        "nb_processes": self.nb_processes,
        "redirect_stdout": self.redirect_stdout,
    }

    formatted_inputs = {}
    for task, v in inputs.items():
        # Normalize the three accepted value forms into
        # (col_mapping, specific_kwargs).
        if isinstance(v, dict):
            col_mapping = v
            specific_kwargs = {}
        elif isinstance(v, InputParameters):
            col_mapping = v.col_mapping
            specific_kwargs = v.kwargs
        else:
            try:
                # In the tuple form the keyword arguments come first.
                specific_kwargs, col_mapping = v
            except (TypeError, ValueError) as exc:
                # TypeError covers non-iterable values, ValueError covers
                # iterables with the wrong number of elements.
                raise ValueError(
                    "The input values should either be a dict containing the column mapping "
                    "or a tuple with a dict containing the keyword arguments as first element "
                    "and a dict containing the column mapping as second element."
                ) from exc
        kwargs = dict(propagated_kwargs)
        # Explicit parameter values override the propagated ones, but ``None``
        # means "not set" and is skipped so the propagated value survives.
        base_kwargs = {
            key: value
            for key, value in task.get_param_values(task.get_params(), [], specific_kwargs)
            if value is not None
        }
        kwargs.update(base_kwargs)
        formatted_inputs[task(**kwargs)] = col_mapping
    return formatted_inputs
347+
246348
def requires(self):
    """Build the requirement list from the processed inputs and any extra requirements."""
    upstream = self.processed_inputs()
    required = [] if not upstream else list(upstream)
    return required + task_flatten(self.extra_requires())
263352

264353
def extra_requires(self):
@@ -364,9 +453,11 @@ def _get_dataset(self):
364453

365454
def _join_inputs(self, new_df):
366455
"""Get the inputs and join them to the dataset."""
367-
if self.inputs():
456+
if self.processed_inputs():
368457
# Check inputs
369-
with_mapping = self.check_inputs(self.inputs()) # pylint: disable=not-callable
458+
with_mapping = self.check_inputs(
459+
self.processed_inputs()
460+
) # pylint: disable=not-callable
370461

371462
# Get the input targets and their DataFrames
372463
all_inputs = {
@@ -375,7 +466,7 @@ def _join_inputs(self, new_df):
375466
for target in task_flatten(i.output())
376467
if isinstance(target, ReportTarget)
377468
]
378-
for i in task_flatten(self.inputs().keys())
469+
for i in task_flatten(self.processed_inputs().keys())
379470
}
380471
all_report_paths = {
381472
t: [r.path for r in reports][0] for t, reports in all_inputs.items()
@@ -396,7 +487,7 @@ def _join_inputs(self, new_df):
396487
# Filter columns in the DataFrames
397488
if with_mapping:
398489
# pylint: disable=not-callable
399-
filtered_dfs = self.filter_columns(all_dfs, self.inputs())
490+
filtered_dfs = self.filter_columns(all_dfs, self.processed_inputs())
400491

401492
# Concatenate all DataFrames
402493
filtered_df = pd.concat(filtered_dfs, axis=1)
@@ -443,7 +534,7 @@ def _join_inputs(self, new_df):
443534
def run(self):
444535
"""The main process of the current task."""
445536
# Import the DataFrame(s)
446-
if self.dataset_df is None and self.inputs() is None:
537+
if self.dataset_df is None and self.processed_inputs() is None:
447538
raise ValueError("Either the 'dataset_df' parameter or a requirement must be provided.")
448539

449540
new_df = self._get_dataset()

0 commit comments

Comments
 (0)