From c0374b6b35f036274d39bd594f6aeea1fb3ec057 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Fri, 23 Feb 2024 16:08:32 +0000 Subject: [PATCH 1/2] implement date shifting --- .../create/functions/actions/constants.py | 69 ----- ecml_tools/create/group.py | 245 ------------------ ecml_tools/create/input.py | 146 +++++++++-- tests/create-shift.yaml | 62 +++++ 4 files changed, 193 insertions(+), 329 deletions(-) delete mode 100644 ecml_tools/create/group.py create mode 100644 tests/create-shift.yaml diff --git a/ecml_tools/create/functions/actions/constants.py b/ecml_tools/create/functions/actions/constants.py index 585290d..2ea043a 100644 --- a/ecml_tools/create/functions/actions/constants.py +++ b/ecml_tools/create/functions/actions/constants.py @@ -6,54 +6,8 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. # -from copy import deepcopy - from climetlab import load_source -from ecml_tools.create.utils import to_datetime_list - -DEBUG = True - - -def to_list(x): - if isinstance(x, (list, tuple)): - return x - return [x] - - -def get_template_field(request): - """Create a template request from the initial request, setting the date, time, - levtype and param fields.""" - template_request = { - "class": "ea", - "expver": "0001", - "type": "an", - "date": "20200101", - "time": "0000", - "levtype": "sfc", - "param": "2t", - } - for k in ["area", "grid"]: # is class needed? - if k in request: - template_request[k] = request[k] - template = load_source("mars", template_request) - assert len(template) == 1, (len(template), template_request) - return template - - -def normalise_time_to_hours(r): - r = deepcopy(r) - if "time" not in r: - return r - - times = [] - for t in to_list(r["time"]): - assert len(t) == 4, r - assert t.endswith("00"), r - times.append(int(t) // 100) - r["time"] = tuple(times) - return r - def constants(context, dates, template, param): context.trace("✅", f"load_source(constants, {template}, {param}") @@ -61,26 +15,3 @@ def constants(context, dates, template, param): execute = constants - -if __name__ == "__main__": - import yaml - - config = yaml.safe_load( - """ - class: ea - expver: '0001' - grid: 20.0/20.0 - levtype: sfc - # param: [10u, 10v, 2d, 2t, lsm, msl, sdor, skt, slor, sp, tcw, z] - number: [0, 1] - param: [cos_latitude] - """ - ) - dates = yaml.safe_load( - "[2022-12-30 18:00, 2022-12-31 00:00, 2022-12-31 06:00, 2022-12-31 12:00]" - ) - dates = to_datetime_list(dates) - - DEBUG = True - for f in constants(None, dates, **config): - print(f, f.to_numpy().mean()) diff --git a/ecml_tools/create/group.py b/ecml_tools/create/group.py deleted file mode 100644 index 10cdc43..0000000 --- a/ecml_tools/create/group.py +++ /dev/null @@ -1,245 +0,0 @@ -# (C) Copyright 2023 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# -import datetime -import itertools -from functools import cached_property - -from .utils import to_datetime - - -class GroupByDays: - def __init__(self, days): - self.days = days - - def __call__(self, dt): - year = dt.year - days = (dt - datetime.datetime(year, 1, 1)).days - x = (year, days // self.days) - return x - - -class Group(list): - """Interface wrapper for List objects.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - assert len(self) >= 1, self - assert all(isinstance(_, datetime.datetime) for _ in self), self - - def __repr__(self): - try: - content = ",".join([str(_.strftime("%Y-%m-%d:%H")) for _ in self]) - return f"Group({len(self)}, {content})" - except Exception: - return super().__repr__() - - -class BaseGroups: - def __repr__(self): - try: - content = "+".join([str(len(list(g))) for g in self.groups]) - print(content) - for g in self.groups: - assert isinstance(g[0], datetime.datetime), g[0] - print("val", self.values, self.n_groups) - return f"{self.__class__.__name__}({content}={len(self.values)})({self.n_groups} groups)" - except: # noqa - return f"{self.__class__.__name__}({len(self.values)} dates)" - - @cached_property - def values(self): - raise NotImplementedError() - - def intersect(self, dates): - if dates is None: - return self - # before creating GroupsIntersection - # we make sure that dates it's also a Groups Instance - if not isinstance(dates, Groups): - dates = build_groups(dates) - return GroupsIntersection(self, dates) - - def empty(self): - return len(self.values) == 0 - - @property - def frequency(self): - datetimes = self.values - freq = (datetimes[1] - datetimes[0]).total_seconds() / 3600 - assert round(freq) == freq, freq - assert int(freq) == freq, freq - frequency = int(freq) - return frequency - - -class Groups(BaseGroups): - def __init__(self, config): - # Assert config input is ad dict but not a nested dict - if not isinstance(config, dict): - raise ValueError(f"Config must be a dict. {config}") - for k, v in config.items(): - if isinstance(v, dict): - raise ValueError(f"Values can't be a dictionary. {k,v}") - - self._config = config - - @property - def groups(self): - # Return a list where each sublist contain the subgroups - # of values according to the grouper_key - return [ - Group(g) for _, g in itertools.groupby(self.values, key=self.grouper_key) - ] - - @property - def grouper_key(self): - group_by = self._config.get("group_by") - if isinstance(group_by, int) and group_by > 0: - return GroupByDays(group_by) - return { - None: lambda dt: 0, # only one group - 0: lambda dt: 0, # only one group - "monthly": lambda dt: (dt.year, dt.month), - "daily": lambda dt: (dt.year, dt.month, dt.day), - "weekly": lambda dt: (dt.weekday(),), - "MMDD": lambda dt: (dt.month, dt.day), - }[group_by] - - @cached_property - def n_groups(self): - return len(self.groups) - - -class ExpandGroups(Groups): - def __init__(self, config): - super().__init__(config) - - def _(x): - if isinstance(x, str): - return to_datetime(x) - return x - - self.values = [_(x) for x in self._config.get("values")] - - -class SingleGroup(Groups): - def __init__(self, group): - self.values = group - - @property - def groups(self): - return [Group(self.values)] - - -class DateStartStopGroups(Groups): - def __init__(self, config): - super().__init__(config) - - @cached_property - def start(self): - return self._get_date("start") - - @cached_property - def end(self): - return self._get_date("end") - - def _get_date(self, date_key): - date = self._config[date_key] - if isinstance(date, str): - try: - # Attempt to parse the date string with timestamp format - check_timestamp = datetime.datetime.strptime(date, "%Y-%m-%dT%H:%M:%S") - if check_timestamp: - return to_datetime(date) - except ValueError: - raise ValueError( - f"{date_key} must include timestamp not just date {date,type(date)}" - ) - elif type(date) == datetime.date: # noqa: E721 - raise ValueError( - f"{date_key} must include timestamp not just date {date,type(date)}" - ) - else: - return date - - def _validate_date_range(self): - assert ( - self.end >= self.start - ), "End date must be greater than or equal to start date." - - def _extract_frequency(self, frequency_str): - freq_ending = frequency_str.lower()[-1] - freq_mapping = {"h": int(frequency_str[:-1]), "d": int(frequency_str[:-1]) * 24} - try: - return freq_mapping[freq_ending] - except: # noqa: E722 - raise ValueError( - f"Frequency must be in hours or days (12h or 2d). {frequency_str}" - ) - - def _validate_frequency(self, freq, frequency_str): - if freq > 24 and freq % 24 != 0: - raise ValueError( - f"Frequency must be less than 24h or a multiple of 24h. {frequency_str}" - ) - - @cached_property - def step(self): - _frequency_str = self._config.get("frequency", "1h") - _freq = self._extract_frequency(_frequency_str) - self._validate_frequency(_freq, _frequency_str) - return datetime.timedelta(hours=_freq) - - @cached_property - def values(self): - x = self.start - dates = [] - while x <= self.end: - dates.append(x) - - x += self.step - assert isinstance(dates[0], datetime.datetime), dates[0] - return dates - - -class EmptyGroups(BaseGroups): - def __init__(self): - self.values = [] - self.groups = [] - - @property - def frequency(self): - return None - - -class GroupsIntersection(BaseGroups): - def __init__(self, a, b): - assert isinstance(a, Groups), a - assert isinstance(b, Groups), b - self.a = a - self.b = b - - @cached_property - def values(self): - return list(set(self.a.values) & set(self.b.values)) - - -def build_groups(config): - if isinstance(config, Group): - return SingleGroup(config) - - assert isinstance(config, dict), config - - if "values" in config: - return ExpandGroups(config) - - if "start" in config and "end" in config: - return DateStartStopGroups(config) - - raise NotImplementedError(config) diff --git a/ecml_tools/create/input.py b/ecml_tools/create/input.py index 463e89f..e3f293e 100644 --- a/ecml_tools/create/input.py +++ b/ecml_tools/create/input.py @@ -6,6 +6,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. # +import datetime import importlib import logging import time @@ -33,6 +34,38 @@ LOG = logging.getLogger(__name__) +def parse_function_name(name): + if "-" in name: + name, delta = name.split("-") + sign = -1 + + elif "+" in name: + name, delta = name.split("+") + sign = 1 + + else: + return name, None + + assert delta[-1] == "h", (name, delta) + delta = sign * int(delta[:-1]) + return name, delta + + +def time_delta_to_string(delta): + assert isinstance(delta, datetime.timedelta), delta + seconds = delta.total_seconds() + hours = int(seconds // 3600) + assert hours * 3600 == seconds, delta + hours = abs(hours) + + if seconds > 0: + return f"plus_{hours}h" + if seconds == 0: + return "" + if seconds < 0: + return f"minus_{hours}h" + + def import_function(name, kind): return importlib.import_module( f"..functions.{kind}.{name}", @@ -41,6 +74,7 @@ def import_function(name, kind): def is_function(name, kind): + name, delta = parse_function_name(name) try: import_function(name, kind) return True @@ -432,6 +466,81 @@ def __repr__(self): return super().__repr__(content) +class _DatesShiftAction(Action): + def __init__(self, context, action_path, _delta, **kwargs): + super().__init__(context, action_path, **kwargs) + + if isinstance(_delta, str): + if _delta[0] == "-": + _delta, sign = int(_delta[1:]), -1 + else: + _delta, sign = int(_delta), 1 + _delta = datetime.timedelta(hours=sign * _delta) + assert isinstance(_delta, int), _delta + _delta = datetime.timedelta(hours=_delta) + self.delta = _delta + + self.content = action_factory(kwargs, context, self.action_path + ["shift"]) + + def is_shift_action(cls, name): + return "-" in name or "+" in name + + @trace_select + def select(self, dates): + shifted_dates = [d + self.delta for d in dates] + result = self.content.select(shifted_dates) + return UnShiftResult(self.context, self.action_path, dates, result, action=self) + + def __repr__(self): + return super().__repr__(f"{self.delta}\n{self.content}") + + +class UnShiftResult(Result): + def __init__(self, context, action_path, dates, result, action): + super().__init__(context, action_path, dates) + # dates are the actual requested dates + # result does not have the same dates + self.action = action + self.result = result + + def _trace_datasource(self, *args, **kwargs): + return f"{self.action.delta}({shorten(self.dates)})" + + @cached_property + @assert_fieldset + @notify_result + @trace_datasource + def datasource(self): + from climetlab.indexing.fieldset import FieldArray + + class DateShiftedField: + def __init__(self, field, delta): + self.field = field + self.delta = delta + + def metadata(self, key): + value = self.field.metadata(key) + if key == "param": + return value + "_" + time_delta_to_string(self.delta) + if key == "valid_datetime": + dt = datetime.datetime.fromisoformat(value) + new_dt = dt - self.delta + new_value = new_dt.isoformat() + return new_value + if key in ["date", "time", "step", "hdate"]: + raise NotImplementedError( + f"metadata {key} not implemented when shifting dates" + ) + return value + + def __getattr__(self, name): + return getattr(self.field, name) + + ds = self.result.datasource + ds = FieldArray([DateShiftedField(fs, self.action.delta) for fs in ds]) + return ds + + class FunctionAction(Action): def __init__(self, context, action_path, _name, **kwargs): super().__init__(context, action_path, **kwargs) @@ -443,6 +552,7 @@ def select(self, dates): @property def function(self): + name, delta = parse_function_name(self.name) return import_function(self.name, "actions") def __repr__(self): @@ -657,26 +767,33 @@ def action_factory(config, context, action_path): config = deepcopy(config) key = list(config.keys())[0] - cls = dict( - concat=ConcatAction, - join=JoinAction, - # label=LabelAction, - pipe=PipeAction, - # source=SourceAction, - function=FunctionAction, - dates=DateAction, - # dependency=DependencyAction, - ).get(key) if isinstance(config[key], list): args, kwargs = config[key], {} - if isinstance(config[key], dict): args, kwargs = [], config[key] + if "-" in key or "+" in key: + new_key, delta = parse_function_name(key) + new_config = dict(_dates_shift={"_delta": delta, new_key: config[key]}) + return action_factory(new_config, context, action_path) + + else: + cls = dict( + _dates_shift=_DatesShiftAction, + concat=ConcatAction, + join=JoinAction, + # label=LabelAction, + pipe=PipeAction, + # source=SourceAction, + function=FunctionAction, + dates=DateAction, + # dependency=DependencyAction, + ).get(key) + if cls is None: if not is_function(key, "actions"): - raise ValueError(f"Unknown action {key}") + raise ValueError(f"Unknown action '{key}' in {config}") cls = FunctionAction args = [key] + args @@ -714,9 +831,8 @@ def step_factory(config, context, action_path, previous_step): class FunctionContext: - """A FunctionContext is passed to all functions, it will be used to - pass information to the functions from the other actions and steps and results. - """ + """A FunctionContext is passed to all functions, it will be used to pass information + to the functions from the other actions and steps and results.""" def __init__(self, owner): self.owner = owner diff --git a/tests/create-shift.yaml b/tests/create-shift.yaml new file mode 100644 index 0000000..6cc4b3c --- /dev/null +++ b/tests/create-shift.yaml @@ -0,0 +1,62 @@ +description: "develop version of the dataset for a few days and a few variables, once data on mars is cached it should take a few seconds to generate the dataset" +dataset_status: testing +purpose: aifs +name: test-small +config_format_version: 2 + +common: + mars_request: &mars_request + expver: "0001" + class: ea + grid: 20./20. + +dates: + start: 2020-12-30 00:00:00 + end: 2021-01-03 12:00:00 + frequency: 12h + group_by: monthly + +input: + join: + - mars: + <<: *mars_request + param: [2t] + levtype: sfc + stream: oper + type: an + + - constants: + template: ${input.join.0.mars} + param: + - insolation + + + - constants-3h: + template: ${input.join.0.mars} + param: + - insolation + - constants+6h: + template: ${input.join.0.mars} + param: + - insolation + + - _dates_shift: # private, do not use directly + #_delta: 6 + _delta: -25 + constants: + template: ${input.join.0.mars} + param: + - insolation + +output: + chunking: { dates: 1, ensembles: 1 } + dtype: float32 + flatten_grid: True + order_by: + - valid_datetime + - param_level + - number + statistics: param_level + statistics_end: 2021 + remapping: &remapping + param_level: "{param}_{levelist}" From df2df0c69ea06282528560d9a193f37fd330ef31 Mon Sep 17 00:00:00 2001 From: Florian Pinault Date: Mon, 26 Feb 2024 10:43:13 +0000 Subject: [PATCH 2/2] clean date shift action --- ecml_tools/create/input.py | 50 +++++++++++++++----------------------- tests/create-shift.yaml | 15 ++---------- 2 files changed, 22 insertions(+), 43 deletions(-) diff --git a/ecml_tools/create/input.py b/ecml_tools/create/input.py index e3f293e..4f2478d 100644 --- a/ecml_tools/create/input.py +++ b/ecml_tools/create/input.py @@ -466,25 +466,22 @@ def __repr__(self): return super().__repr__(content) -class _DatesShiftAction(Action): - def __init__(self, context, action_path, _delta, **kwargs): +class DateShiftAction(Action): + def __init__(self, context, action_path, delta, **kwargs): super().__init__(context, action_path, **kwargs) - if isinstance(_delta, str): - if _delta[0] == "-": - _delta, sign = int(_delta[1:]), -1 + if isinstance(delta, str): + if delta[0] == "-": + delta, sign = int(delta[1:]), -1 else: - _delta, sign = int(_delta), 1 - _delta = datetime.timedelta(hours=sign * _delta) - assert isinstance(_delta, int), _delta - _delta = datetime.timedelta(hours=_delta) - self.delta = _delta + delta, sign = int(delta), 1 + delta = datetime.timedelta(hours=sign * delta) + assert isinstance(delta, int), delta + delta = datetime.timedelta(hours=delta) + self.delta = delta self.content = action_factory(kwargs, context, self.action_path + ["shift"]) - def is_shift_action(cls, name): - return "-" in name or "+" in name - @trace_select def select(self, dates): shifted_dates = [d + self.delta for d in dates] @@ -773,23 +770,16 @@ def action_factory(config, context, action_path): if isinstance(config[key], dict): args, kwargs = [], config[key] - if "-" in key or "+" in key: - new_key, delta = parse_function_name(key) - new_config = dict(_dates_shift={"_delta": delta, new_key: config[key]}) - return action_factory(new_config, context, action_path) - - else: - cls = dict( - _dates_shift=_DatesShiftAction, - concat=ConcatAction, - join=JoinAction, - # label=LabelAction, - pipe=PipeAction, - # source=SourceAction, - function=FunctionAction, - dates=DateAction, - # dependency=DependencyAction, - ).get(key) + cls = dict( + date_shift=DateShiftAction, + # date_filter=DateFilterAction, + # include=IncludeAction, + concat=ConcatAction, + join=JoinAction, + pipe=PipeAction, + function=FunctionAction, + dates=DateAction, + ).get(key) if cls is None: if not is_function(key, "actions"): diff --git a/tests/create-shift.yaml b/tests/create-shift.yaml index 6cc4b3c..72dc620 100644 --- a/tests/create-shift.yaml +++ b/tests/create-shift.yaml @@ -30,19 +30,8 @@ input: param: - insolation - - - constants-3h: - template: ${input.join.0.mars} - param: - - insolation - - constants+6h: - template: ${input.join.0.mars} - param: - - insolation - - - _dates_shift: # private, do not use directly - #_delta: 6 - _delta: -25 + - date_shift: + delta: -25 constants: template: ${input.join.0.mars} param: