Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
refactoring. pip, join, concat ok.
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Jan 29, 2024
1 parent 04dd14d commit af202a5
Show file tree
Hide file tree
Showing 4 changed files with 267 additions and 84 deletions.
257 changes: 182 additions & 75 deletions ecml_tools/create/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,17 @@ def coords(self):

class Action:
def __init__(self, context, /, *args, **kwargs):
if "args" in kwargs and "kwargs" in kwargs:
"""We have:
args = []
kwargs = {args: [...], kwargs: {...}}
move the content of kwargs to args and kwargs.
"""
assert len(kwargs) == 2, (args, kwargs)
assert not args, (args, kwargs)
args = kwargs.pop("args")
kwargs = kwargs.pop("kwargs")

assert isinstance(context, Context), type(context)
self.context = context
self.kwargs = kwargs
Expand All @@ -212,18 +223,18 @@ def _short_str(cls, x):
return x
return x[:1000] + "..."

def __repr__(self, *args, __indent__="\n", **kwargs):
def __repr__(self, *args, _indent_="\n", _inline_="", **kwargs):
more = ",".join([str(a)[:5000] for a in args])
more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])

more = more[:5000]
txt = f"{self.__class__.__name__}:{__indent__}{more}"
if __indent__:
txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}"
if _indent_:
txt = txt.replace("\n", "\n ")
return txt

def select(self, dates, **kwargs):
return result_factory(self, dates=dates, **kwargs)
self._raise_not_implemented()

def _raise_not_implemented(self):
raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")
Expand Down Expand Up @@ -269,7 +280,7 @@ def get_cube(self):

return cube

def __repr__(self, *args, __indent__="\n", **kwargs):
def __repr__(self, *args, _indent_="\n", **kwargs):
more = ",".join([str(a)[:5000] for a in args])
more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])

Expand All @@ -283,8 +294,8 @@ def __repr__(self, *args, __indent__="\n", **kwargs):
dates += ")"

more = more[:5000]
txt = f"{self.__class__.__name__}:{dates}{__indent__}{more}"
if __indent__:
txt = f"{self.__class__.__name__}:{dates}{_indent_}{more}"
if _indent_:
txt = txt.replace("\n", "\n ")
return txt

Expand Down Expand Up @@ -331,11 +342,6 @@ def __init__(self, context, dates, action, previous_sibling=None):
_args = self.action.args
_kwargs = self.action.kwargs

if "@" in _kwargs:
name = _kwargs.pop("@")
context.register_reference(name, self)

print("✅dates in source result:", dates)
vars = ReferencesSolver(context, dates)

self.args = substitute(_args, vars)
Expand Down Expand Up @@ -367,44 +373,45 @@ def datasource(self):
assert_is_fieldset(ds), i
return ds

@property
def variables(self):
variables = super().variables
print("🆗variables in JoinResult:", variables)
return variables

def __repr__(self):
content = "\n".join([str(i) for i in self.results])
return super().__repr__(content)


class SourceAction(Action):
def __init__(self, context, *args, **kwargs):
if "args" in kwargs and "kwargs" in kwargs:
assert len(kwargs) == 2, (args, kwargs)
assert not args, (args, kwargs)
args = kwargs.pop("args")
kwargs = kwargs.pop("kwargs")
print("sourceaction", args, kwargs)
class LabelAction(Action):
def __init__(self, context, name, **kwargs):
super().__init__(context)
if len(kwargs) != 1:
raise ValueError(f"Invalid kwargs for label : {kwargs}")
self.name = name
self.content = action_factory(kwargs, context)

super().__init__(context, *args, **kwargs)
def select(self, dates):
result = self.content.select(dates)
self.context.register_reference(self.name, result)
return result

def __repr__(self):
return super().__repr__(_inline_=self.name, _indent_=" ")


class SourceAction(Action):
def __repr__(self):
content = ""
content += ",".join([self._short_str(a) for a in self.args])
content += " ".join(
[self._short_str(f"{k}={v}") for k, v in self.kwargs.items()]
)
content = self._short_str(content)
return super().__repr__(content)
return super().__repr__(_inline_=content, _indent_=" ")

def select(self, dates):
return SourceResult(self.context, dates, action=self)


class ConcatResult(Result):
def __init__(self, context, dates, results):
super().__init__(context, dates)
def __init__(self, context, results):
super().__init__(context, dates=None)
self.results = [r for r in results if not r.empty]

@property
Expand All @@ -417,6 +424,7 @@ def datasource(self):

@property
def variables(self):
"""Check that all the results objects have the same variables"""
variables = None
for f in self.results:
if f.empty:
Expand All @@ -429,6 +437,7 @@ def variables(self):

@property
def dates(self):
"""Merge the dates of all the results objects"""
dates = []
for i in self.results:
d = i.dates
Expand All @@ -448,6 +457,8 @@ def __repr__(self):


class ActionWithList(Action):
result_class = None

def __init__(self, context, *configs):
super().__init__(context, *configs)
self.actions = [action_factory(c, context) for c in configs]
Expand All @@ -457,10 +468,83 @@ def __repr__(self):
return super().__repr__(content)


class PipeAction(Action):
def __init__(self, context, *configs):
super().__init__(context, *configs)
current = action_factory(configs[0], context)
for c in configs[1:]:
current = step_factory(c, context, _upstream_action=current)
self.content = current

def select(self, dates):
return self.content.select(dates)

def __repr__(self):
return super().__repr__(self.content)


class StepResult(Result):
def __init__(self, upstream, context, dates, action):
super().__init__(context, dates)
assert isinstance(upstream, Result), type(upstream)
self.content = upstream
self.action = action

@property
def datasource(self):
return self.content.datasource


class StepAction(Action):
result_class = None

def __init__(self, context, _upstream_action, **kwargs):
super().__init__(context, **kwargs)
self.content = _upstream_action

def select(self, dates):
return self.result_class(
self.content.select(dates),
self.context,
dates,
self,
)

def __repr__(self):
return super().__repr__(self.content, _inline_=str(self.kwargs))


class FilterResult(StepResult):
@property
def datasource(self):
ds = self.content.datasource
assert_is_fieldset(ds)
ds = ds.sel(**self.action.kwargs)
assert_is_fieldset(ds)
return ds


class FilterAction(StepAction):
result_class = FilterResult


# class RenameResult(StepResult):
# @property
# def datasource(self):
# ds = self.content.datasource
# assert_is_fieldset(ds)
# ds = ds.rename(**self.action.kwargs)
# assert_is_fieldset(ds)
# return ds
#
#
# class RenameAction(StepAction):
# result_class = RenameResult


class ConcatAction(ActionWithList):
def select(self, dates):
print("🆗dates in ConcatAction select:", dates)
return ConcatResult(self.context, None, [a.select(dates) for a in self.actions])
return ConcatResult(self.context, [a.select(dates) for a in self.actions])


class JoinAction(ActionWithList):
Expand All @@ -487,8 +571,6 @@ def select(self, dates):
newdates = self._dates.intersect(dates)
if newdates.empty():
return EmptyResult(self.context, dates=newdates)

print("🆗dates in DateAction select:", newdates)
return self.content.select(newdates)

def __repr__(self):
Expand All @@ -509,41 +591,34 @@ def merge_dicts(a, b):
return deepcopy(b)


class Context:
def __init__(self, loader=None):
self.order_by = loader.output.order_by
self.flatten_grid = loader.output.flatten_grid
self.remapping = build_remapping(loader.output.remapping)
def action_factory(config, context):
assert isinstance(context, Context), (type, context)
if not isinstance(config, dict):
raise ValueError(f"Invalid input config {config}")

self.references = {}
config = deepcopy(config)
assert len(config) == 1, config

def register_reference(self, name, obj):
assert isinstance(obj, Result), type(obj)
if name in self.references:
raise ValueError(f"Duplicate reference {name}")
self.references[name] = obj
key = list(config.keys())[0]
cls = dict(
concat=ConcatAction,
join=JoinAction,
label=LabelAction,
pipe=PipeAction,
source=SourceAction,
dates=DateAction,
)[key]

def find_reference(self, name):
if name in self.references:
return self.references[name]
# It can happend that the required name is not yet registered,
# even if it is defined in the config.
# Handling this case implies implementing a lazy inheritance resolution
# and would complexify the code. This is not implemented.
raise ValueError(f"Cannot find reference {name}")
if isinstance(config[key], list):
args, kwargs = config[key], {}

# def resolve_inheritance(self, config):
# if not isinstance(config, dict):
# return config
# if "inherit" in config:
# config = deepcopy(config)
# inherit = config.pop("inherit")
# other = self.find_inheritance(inherit)
# return merge_dicts(other._config, config)
# return config
if isinstance(config[key], dict):
args, kwargs = [], config[key]

return cls(context, *args, **kwargs)

def action_factory(config, context):

def step_factory(config, context, _upstream_action):
assert isinstance(context, Context), (type, context)
if not isinstance(config, dict):
raise ValueError(f"Invalid input config {config}")
Expand All @@ -553,28 +628,60 @@ def action_factory(config, context):

key = list(config.keys())[0]
cls = dict(
concat=ConcatAction,
join=JoinAction,
# pipe=PipeAction,
source=SourceAction,
# remapping=RemappingAction,
dates=DateAction,
filter=FilterAction,
# rename=RenameAction,
# remapping=RemappingAction,
)[key]

if isinstance(config[key], list):
args, kwargs = config[key], {}

if isinstance(config[key], dict):
args, kwargs = [], config[key]

if "_upstream_action" in kwargs:
raise ValueError(f"Reserverd keyword '_upsream_action' in {config}")
kwargs["_upstream_action"] = _upstream_action

return cls(context, *args, **kwargs)


def result_factory(spec, dates, **kwargs):
return spec.result_class(spec, dates, **kwargs)
class Context:
def __init__(self, loader=None):
self.order_by = loader.output.order_by
self.flatten_grid = loader.output.flatten_grid
self.remapping = build_remapping(loader.output.remapping)

self.references = {}

def register_reference(self, name, obj):
assert isinstance(obj, Result), type(obj)
if name in self.references:
raise ValueError(f"Duplicate reference {name}")
self.references[name] = obj

def find_reference(self, name):
if name in self.references:
return self.references[name]
# It can happend that the required name is not yet registered,
# even if it is defined in the config.
# Handling this case implies implementing a lazy inheritance resolution
# and would complexify the code. This is not implemented.
raise ValueError(f"Cannot find reference {name}")


class InputBuilder:
def __init__(self, config, loader):
self.loader = loader
self.config = config

@property
def _action(self):
context = Context(loader=self.loader)
return action_factory(self.config, context)

def select(self, dates):
return self._action.select(dates)

def build_input(config, loader):
context = Context(loader=loader)
print(config)

return action_factory(config, context)
build_input = InputBuilder
Loading

0 comments on commit af202a5

Please sign in to comment.