diff --git a/ecml_tools/create/config.py b/ecml_tools/create/config.py index 7fddb22..fe0f61d 100644 --- a/ecml_tools/create/config.py +++ b/ecml_tools/create/config.py @@ -169,6 +169,8 @@ def __init__(self, config, *args, **kwargs): if "loop" in self: raise ValueError(f"Do not use 'loop'. Use dates instead. {list(self.keys())}") + self.options = self.get("options", {}) + if "licence" not in self: self.licence = "unknown" print(f"❗ Setting licence={self.licence} because it was not provided.") @@ -198,6 +200,9 @@ def __init__(self, config, *args, **kwargs): assert "flatten_grid" in self.output, self.output assert "statistics" in self.output + if "group_by" in self.options: + self.dates["group_by"] = self.options.group_by + def get_serialisable_dict(self): return _prepare_serialisation(self) diff --git a/ecml_tools/create/input.py b/ecml_tools/create/input.py index 638ac57..5190480 100644 --- a/ecml_tools/create/input.py +++ b/ecml_tools/create/input.py @@ -554,16 +554,6 @@ def _trace_select(self, dates): return f"{self.name}({shorten(dates)})" -class ActionWithList(Action): - def __init__(self, context, action_path, *configs): - super().__init__(context, action_path, *configs) - self.actions = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)] - - def __repr__(self): - content = "\n".join([str(i) for i in self.actions]) - return super().__repr__(content) - - class PipeAction(Action): def __init__(self, context, action_path, *configs): super().__init__(context, action_path, *configs) @@ -694,40 +684,52 @@ def __repr__(self): return super().__repr__(content) -class IncludeResult(Result): - def __init__(self, context, action_path, dates, result, results): +class DataSourcesResult(Result): + def __init__(self, context, action_path, dates, input_result, sources_results): super().__init__(context, action_path, dates) - # result is the content of the include - self.result = result - # results is the list of the included results - self.results = results + # result is the main input result + self.input_result = input_result + # sources_results is the list of the sources_results + self.sources_results = sources_results @cached_property def datasource(self): - for i in self.results: - # for each include trigger the datasource to be computed - # and saved in context but drop it - i.datasource - # then return the content of the result + for i in self.sources_results: + # for each result trigger the datasource to be computed + # and saved in context + self.context.notify_result(i.action_path[:-1], i.datasource) + # then return the input result # which can use the datasources of the included results - return self.result.datasource + return self.input_result.datasource -class IncludeAction(ActionWithList): - def __init__(self, context, action_path, includes, content): - super().__init__(context, ["include"], *includes) - self.content = action_factory(content, context, ["input"]) +class DataSourcesAction(Action): + def __init__(self, context, action_path, sources, input): + super().__init__(context, ["data_sources"], *sources) + if isinstance(sources, dict): + configs = [(str(k), c) for k, c in sources.items()] + elif isinstance(sources, list): + configs = [(str(i), c) for i, c in enumerate(sources)] + else: + raise ValueError(f"Invalid data_sources, expecting list or dict, got {type(sources)}: {sources}") + + self.sources = [action_factory(config, context, ["data_sources"] + [a_path]) for a_path, config in configs] + self.input = action_factory(input, context, ["input"]) def select(self, dates): - results = [a.select(dates) for a in self.actions] - return IncludeResult( + sources_results = [a.select(dates) for a in self.sources] + return DataSourcesResult( self.context, self.action_path, dates, - self.content.select(dates), - results, + self.input.select(dates), + sources_results, ) + def __repr__(self): + content = "\n".join([str(i) for i in self.sources]) + return super().__repr__(content) + class ConcatAction(Action): def __init__(self, context, action_path, *configs): @@ -761,7 +763,15 @@ def select(self, dates): return ConcatResult(self.context, self.action_path, dates, results) -class JoinAction(ActionWithList): +class JoinAction(Action): + def __init__(self, context, action_path, *configs): + super().__init__(context, action_path, *configs) + self.actions = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)] + + def __repr__(self): + content = "\n".join([str(i) for i in self.actions]) + return super().__repr__(content) + @trace_select def select(self, dates): results = [a.select(dates) for a in self.actions] @@ -783,15 +793,15 @@ def action_factory(config, context, action_path): if isinstance(config[key], dict): args, kwargs = [], config[key] - cls = dict( - # date_shift=DateShiftAction, - # date_filter=DateFilterAction, - include=IncludeAction, - concat=ConcatAction, - join=JoinAction, - pipe=PipeAction, - function=FunctionAction, - ).get(key) + cls = { + # "date_shift": DateShiftAction, + # "date_filter": DateFilterAction, + "data_sources": DataSourcesAction, + "concat": ConcatAction, + "join": JoinAction, + "pipe": PipeAction, + "function": FunctionAction, + }.get(key) if cls is None: if not is_function(key, "actions"): @@ -852,15 +862,15 @@ def __init__(self, /, order_by, flatten_grid, remapping): class InputBuilder: - def __init__(self, config, include, **kwargs): + def __init__(self, config, data_sources, **kwargs): self.kwargs = kwargs config = deepcopy(config) - if include: + if data_sources: config = dict( - include=dict( - includes=include, - content=config, + data_sources=dict( + sources=data_sources, + input=config, ) ) self.config = config diff --git a/ecml_tools/create/loaders.py b/ecml_tools/create/loaders.py index 0fa7ed3..ad45b2b 100644 --- a/ecml_tools/create/loaders.py +++ b/ecml_tools/create/loaders.py @@ -78,7 +78,7 @@ def build_input(self): builder = build_input( self.main_config.input, - include=self.main_config.get("include", {}), + data_sources=self.main_config.get("data_sources", {}), order_by=self.output.order_by, flatten_grid=self.output.flatten_grid, remapping=build_remapping(self.output.remapping), diff --git a/tests/create-perturbations.yaml b/tests/create-perturbations.yaml index c4873f4..29c177a 100644 --- a/tests/create-perturbations.yaml +++ b/tests/create-perturbations.yaml @@ -34,44 +34,48 @@ dates: start: 2020-12-30 00:00:00 end: 2021-01-03 12:00:00 frequency: 12h + +options: group_by: monthly -include: # This "include" will be renamed - - join: +data_sources: + ensembles: + join: - mars: <<: *ensembles <<: *common - accumulations: <<: *ensembles <<: *common_acc - - join: + mean: + join: - mars: <<: *mean <<: *common - accumulations: <<: *mean <<: *common_acc - - join: + center: + join: - mars: <<: *center <<: *common - accumulations: <<: *center <<: *common_acc - + input: ensemble_perturbations: # the ensemble data which has one additional dimension - ensembles: ${include.0.join} + ensembles: ${data_sources.ensembles} # the previous center of the data - mean: ${include.1.join} + mean: ${data_sources.mean} # the new center of the data - center: ${include.2.join} + center: ${data_sources.center} output: chunking: { dates: 1 } dtype: float32 - flatten_grid: True order_by: - valid_datetime - param_level