Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
some renaming : include(list) -> data_sources(dict), move group_by in…
Browse files Browse the repository at this point in the history
… options:
floriankrb committed Mar 11, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent a098768 commit 8b967e8
Showing 4 changed files with 74 additions and 55 deletions.
5 changes: 5 additions & 0 deletions ecml_tools/create/config.py
Original file line number Diff line number Diff line change
@@ -169,6 +169,8 @@ def __init__(self, config, *args, **kwargs):
if "loop" in self:
raise ValueError(f"Do not use 'loop'. Use dates instead. {list(self.keys())}")

self.options = self.get("options", {})

if "licence" not in self:
self.licence = "unknown"
print(f"❗ Setting licence={self.licence} because it was not provided.")
@@ -198,6 +200,9 @@ def __init__(self, config, *args, **kwargs):
assert "flatten_grid" in self.output, self.output
assert "statistics" in self.output

if "group_by" in self.options:
self.dates["group_by"] = self.options.group_by

def get_serialisable_dict(self):
    """Return the result of passing this object to ``_prepare_serialisation``."""
    serialisable = _prepare_serialisation(self)
    return serialisable

100 changes: 55 additions & 45 deletions ecml_tools/create/input.py
Original file line number Diff line number Diff line change
@@ -554,16 +554,6 @@ def _trace_select(self, dates):
return f"{self.name}({shorten(dates)})"


class ActionWithList(Action):
    """Action that owns one child action per positional config."""

    def __init__(self, context, action_path, *configs):
        super().__init__(context, action_path, *configs)
        # Build the children in order; each child's path is the parent path
        # extended with the child's position rendered as a string.
        children = []
        for position, child_config in enumerate(configs):
            child_path = action_path + [str(position)]
            children.append(action_factory(child_config, context, child_path))
        self.actions = children

    def __repr__(self):
        # One line per child, handed to the base class representation.
        rendered_children = "\n".join(str(child) for child in self.actions)
        return super().__repr__(rendered_children)


class PipeAction(Action):
def __init__(self, context, action_path, *configs):
super().__init__(context, action_path, *configs)
@@ -694,40 +684,52 @@ def __repr__(self):
return super().__repr__(content)


class IncludeResult(Result):
def __init__(self, context, action_path, dates, result, results):
class DataSourcesResult(Result):
def __init__(self, context, action_path, dates, input_result, sources_results):
super().__init__(context, action_path, dates)
# result is the content of the include
self.result = result
# results is the list of the included results
self.results = results
# input_result is the result of the main input
self.input_result = input_result
# sources_results is the list of results selected from the data sources
self.sources_results = sources_results

@cached_property
def datasource(self):
for i in self.results:
# for each include trigger the datasource to be computed
# and saved in context but drop it
i.datasource
# then return the content of the result
for i in self.sources_results:
# for each source result, compute its datasource and pass it to
# context.notify_result under the result's action path minus its last element
self.context.notify_result(i.action_path[:-1], i.datasource)
# then return the input result
# which can use the datasources of the included results
return self.result.datasource
return self.input_result.datasource


class IncludeAction(ActionWithList):
def __init__(self, context, action_path, includes, content):
super().__init__(context, ["include"], *includes)
self.content = action_factory(content, context, ["input"])
class DataSourcesAction(Action):
def __init__(self, context, action_path, sources, input):
# NOTE(review): the action_path argument is not used; the base class always
# receives the literal path ["data_sources"] — confirm this is intended.
# NOTE(review): *sources is unpacked before the isinstance checks below; for a
# dict this passes its keys, and a non-iterable value raises TypeError here
# before the ValueError below can be raised.
super().__init__(context, ["data_sources"], *sources)
# normalise sources into (path element, config) pairs, using the dict keys
# or the list indices converted to strings
if isinstance(sources, dict):
configs = [(str(k), c) for k, c in sources.items()]
elif isinstance(sources, list):
configs = [(str(i), c) for i, c in enumerate(sources)]
else:
raise ValueError(f"Invalid data_sources, expecting list or dict, got {type(sources)}: {sources}")

# one action per source, with path ["data_sources", <key or index>]
self.sources = [action_factory(config, context, ["data_sources"] + [a_path]) for a_path, config in configs]
# the main input action is built with the fixed path ["input"]
self.input = action_factory(input, context, ["input"])

def select(self, dates):
results = [a.select(dates) for a in self.actions]
return IncludeResult(
sources_results = [a.select(dates) for a in self.sources]
return DataSourcesResult(
self.context,
self.action_path,
dates,
self.content.select(dates),
results,
self.input.select(dates),
sources_results,
)

def __repr__(self):
# one line per source action, handed to the base class representation
content = "\n".join([str(i) for i in self.sources])
return super().__repr__(content)


class ConcatAction(Action):
def __init__(self, context, action_path, *configs):
@@ -761,7 +763,15 @@ def select(self, dates):
return ConcatResult(self.context, self.action_path, dates, results)


class JoinAction(ActionWithList):
class JoinAction(Action):
def __init__(self, context, action_path, *configs):
super().__init__(context, action_path, *configs)
self.actions = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)]

def __repr__(self):
content = "\n".join([str(i) for i in self.actions])
return super().__repr__(content)

@trace_select
def select(self, dates):
results = [a.select(dates) for a in self.actions]
@@ -783,15 +793,15 @@ def action_factory(config, context, action_path):
if isinstance(config[key], dict):
args, kwargs = [], config[key]

cls = dict(
# date_shift=DateShiftAction,
# date_filter=DateFilterAction,
include=IncludeAction,
concat=ConcatAction,
join=JoinAction,
pipe=PipeAction,
function=FunctionAction,
).get(key)
cls = {
# "date_shift": DateShiftAction,
# "date_filter": DateFilterAction,
"data_sources": DataSourcesAction,
"concat": ConcatAction,
"join": JoinAction,
"pipe": PipeAction,
"function": FunctionAction,
}.get(key)

if cls is None:
if not is_function(key, "actions"):
@@ -852,15 +862,15 @@ def __init__(self, /, order_by, flatten_grid, remapping):


class InputBuilder:
def __init__(self, config, include, **kwargs):
def __init__(self, config, data_sources, **kwargs):
self.kwargs = kwargs

config = deepcopy(config)
if include:
if data_sources:
config = dict(
include=dict(
includes=include,
content=config,
data_sources=dict(
sources=data_sources,
input=config,
)
)
self.config = config
2 changes: 1 addition & 1 deletion ecml_tools/create/loaders.py
Original file line number Diff line number Diff line change
@@ -78,7 +78,7 @@ def build_input(self):

builder = build_input(
self.main_config.input,
include=self.main_config.get("include", {}),
data_sources=self.main_config.get("data_sources", {}),
order_by=self.output.order_by,
flatten_grid=self.output.flatten_grid,
remapping=build_remapping(self.output.remapping),
22 changes: 13 additions & 9 deletions tests/create-perturbations.yaml
Original file line number Diff line number Diff line change
@@ -34,44 +34,48 @@ dates:
start: 2020-12-30 00:00:00
end: 2021-01-03 12:00:00
frequency: 12h

options:
group_by: monthly

include: # This "include" will be renamed
- join:
data_sources:
ensembles:
join:
- mars:
<<: *ensembles
<<: *common
- accumulations:
<<: *ensembles
<<: *common_acc
- join:
mean:
join:
- mars:
<<: *mean
<<: *common
- accumulations:
<<: *mean
<<: *common_acc
- join:
center:
join:
- mars:
<<: *center
<<: *common
- accumulations:
<<: *center
<<: *common_acc

input:
ensemble_perturbations:
# the ensemble data which has one additional dimension
ensembles: ${include.0.join}
ensembles: ${data_sources.ensembles}
# the previous center of the data
mean: ${include.1.join}
mean: ${data_sources.mean}
# the new center of the data
center: ${include.2.join}
center: ${data_sources.center}

output:
chunking: { dates: 1 }
dtype: float32
flatten_grid: True
order_by:
- valid_datetime
- param_level

0 comments on commit 8b967e8

Please sign in to comment.