@@ -86,6 +86,41 @@ def num_tag(path, num):
         TaggedOutputLocalTarget.set_default_prefix(self.result_path)
 
 
+class InputParameters:
+    """A helper to build task inputs.
+
+    This class should only be used to build the input dictionaries:
+
+    .. code-block:: python
+
+        class Task(ChildBaseValidationTask):
+            data_dir = luigi.Parameter(default='my_data')
+            validation_function = staticmethod(my_validation_function)
+            output_columns = {'my_new_col': None}
+
+            def inputs(self):
+                return {
+                    PreviousTask: InputParameters(
104+ {"input_data", "input_data"},
+                        kwarg_1="value_1",
+                        kwarg_2="value_2"
+                    )
+                }
+
+    Args:
+        col_mapping (dict): The column mapping.
+
+    Keyword Args:
+        kwargs: All the keyword arguments passed to the constructor.
+
+    .. warning:: The keyword arguments will always override the arguments from the config file.
+    """
+
+    def __init__(self, col_mapping, **kwargs):
+        self.col_mapping = col_mapping
+        self.kwargs = kwargs
+
+
 class BaseValidationTask(LogTargetMixin, RerunMixin, TagResultOutputMixin, luigi.Task):
     """Base luigi task used for validation steps.
 
@@ -103,7 +138,7 @@ class Task(ChildBaseValidationTask):
             output_columns = {'my_new_col': None}
 
             def inputs(self):
-                return {PreviousTask(): {"input_data", "input_data"}
+                return {PreviousTask: {"input_data": "input_data"}}
 
     .. note::
 
@@ -195,17 +230,41 @@ def success_summary(self):
     def inputs(self):
         """Information about required input data.
 
-        This method can be overridden and should return a dict of the following form:
+        This method can be overridden and should return a dict in one of the following forms:
 
         .. code-block:: python
 
-            {<task_name>(): {"<input_column_name>": "<current_column_name>"}}
+            {<task_name>: {"<input_column_name>": "<current_column_name>"}}
+
+        .. code-block:: python
+
+            {
+                <task_name>: (
+                    {"<input_column_name>": "<current_column_name>"},
+                    {
+                        "<kwarg_1>": "<value_1>",
+                        "<kwarg_2>": "<value_2>",
+                    }
+                )
+            }
+
+        .. code-block:: python
+
+            {
+                <task_name>: InputParameters(
+                    {"<input_column_name>": "<current_column_name>"},
+                    kwarg_1="<value_1>",
+                    kwarg_2="<value_2>",
+                )
+            }
 
         where:
             - ``<task_name>`` is the name of the required task,
             - ``<input_column_name>`` is the name of the column we need from the report of
              task_name,
             - ``<current_column_name>`` is the name of the same column in the current task.
+            - ``<kwarg_*>`` is the name of a keyword argument passed to the constructor of the
+              required task.
         """
         # pylint: disable=no-self-use
         return None
@@ -243,22 +302,70 @@ def pre_process(self, df, args, kwargs):
     def post_process(self, df, args, kwargs):
         """Method executed after applying the external function."""
 
+    def processed_inputs(self):
306+ """Process the inputs to automatically propagate the values from the workflow."""
+        inputs = self.inputs()  # pylint: disable=assignment-from-none
+
+        if not inputs:
+            return None
+
+        def default_kwargs(self):
+            return {
+                "dataset_df": self.dataset_df,
+                "input_index_col": self.input_index_col,
+                "result_path": self.result_path,
+                "nb_processes": self.nb_processes,
+                "redirect_stdout": self.redirect_stdout,
+            }
+
+        formatted_inputs = {}
+        for task, v in inputs.items():
+            if isinstance(v, dict):
+                col_mapping = v
+                specific_kwargs = {}
+            elif isinstance(v, InputParameters):
+                col_mapping = v.col_mapping
+                specific_kwargs = v.kwargs
+            else:
+                try:
+                    col_mapping, specific_kwargs = v
+                except ValueError as exc:
+                    raise ValueError(
+                        "The input values should either be a dict containing the column mapping "
+                        "or a tuple with a dict containing the column mapping as first element "
+                        "and a dict containing the keyword arguments as second element."
+                    ) from exc
+            kwargs = default_kwargs(self)
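+            # Resolve the required task's parameter values from the explicit kwargs, the
+            # luigi config and the parameter defaults; None values are dropped so that the
+            # workflow values from ``default_kwargs`` are kept for them.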
+            base_kwargs = {
+                key: value
+                for key, value in task.get_param_values(task.get_params(), [], specific_kwargs)
+                if value is not None
+            }
+            kwargs.update(base_kwargs)
+            formatted_inputs[task(**kwargs)] = col_mapping
+        return formatted_inputs
+
     def requires(self):
         """Process the inputs to generate the requirements."""
-        if self.inputs():
-            requires = list(self.inputs().keys())  # pylint: disable=not-callable
-            for req in requires:
-                if req.dataset_df is None:
-                    req.dataset_df = self.dataset_df
-                    req.input_index_col = self.input_index_col
-                if req.result_path is None:
-                    req.result_path = self.result_path
-                if req.nb_processes is None:
-                    req.nb_processes = self.nb_processes
-                if req.redirect_stdout is None:
-                    req.redirect_stdout = self.redirect_stdout
-        else:
-            requires = []
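+        # ``processed_inputs()`` returns None when no inputs are declared, hence the ``or []``.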
+        requires = list(self.processed_inputs() or [])
         return requires + task_flatten(self.extra_requires())
 
     def extra_requires(self):
@@ -364,9 +453,11 @@ def _get_dataset(self):
 
     def _join_inputs(self, new_df):
         """Get the inputs and join them to the dataset."""
-        if self.inputs():
+        if self.processed_inputs():
             # Check inputs
-            with_mapping = self.check_inputs(self.inputs())  # pylint: disable=not-callable
+            with_mapping = self.check_inputs(
+                self.processed_inputs()
+            )  # pylint: disable=not-callable
 
             # Get the input targets and their DataFrames
             all_inputs = {
@@ -375,7 +466,7 @@ def _join_inputs(self, new_df):
                     for target in task_flatten(i.output())
                     if isinstance(target, ReportTarget)
                 ]
-                for i in task_flatten(self.inputs().keys())
+                for i in task_flatten(self.processed_inputs().keys())
             }
             all_report_paths = {
                 t: [r.path for r in reports][0] for t, reports in all_inputs.items()
@@ -396,7 +487,7 @@ def _join_inputs(self, new_df):
             # Filter columns in the DataFrames
             if with_mapping:
                 # pylint: disable=not-callable
-                filtered_dfs = self.filter_columns(all_dfs, self.inputs())
+                filtered_dfs = self.filter_columns(all_dfs, self.processed_inputs())
 
                 # Concatenate all DataFrames
                 filtered_df = pd.concat(filtered_dfs, axis=1)
@@ -443,7 +534,7 @@ def _join_inputs(self, new_df):
     def run(self):
         """The main process of the current task."""
         # Import the DataFrame(s)
-        if self.dataset_df is None and self.inputs() is None:
+        if self.dataset_df is None and self.processed_inputs() is None:
             raise ValueError("Either the 'dataset_df' parameter or a requirement must be provided.")
 
         new_df = self._get_dataset()