diff --git a/src/earthkit/regrid/backends/__init__.py b/src/earthkit/regrid/backends/__init__.py new file mode 100644 index 0000000..8e3382b --- /dev/null +++ b/src/earthkit/regrid/backends/__init__.py @@ -0,0 +1,210 @@ +# (C) Copyright 2023 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import threading +from abc import ABCMeta +from abc import abstractmethod +from collections import namedtuple + +from earthkit.regrid.utils.config import CONFIG + +BackendKey = namedtuple("BackendKey", ["name", "path"]) + + +class Backend(metaclass=ABCMeta): + @abstractmethod + def interpolate(self, values, in_grid, out_grid, method, **kwargs): + pass + + +class BackendOrder: + DEFAULT_ORDER = ["local-matrix", "plugins", "remote-matrix", "system-matrix", "mir"] + BUILT_IN = {"local-matrix", "remote-matrix", "system-matrix", "mir"} + UNIQUE = {"system-matrix", "mir"} + + def select(self, order, names, backends): + if isinstance(names, str): + if names in self.UNIQUE: + return [b for k, b in backends.items() if k.name == names] + + r = [] + if names: + _order = list(names) + else: + _order = order or self.DEFAULT_ORDER + + print("names", names) + print("order", _order) + + for m in _order: + print("m", m) + for k, b in backends.items(): + print(" k", k) + if k.name == m or (m == "plugins" and k.name not in self.BUILT_IN): + print(" -> b", b) + r.append(b) + + return r + + +BACKEND_ORDER = BackendOrder() + + +class BackendManager: + BACKENDS = {} + + def __init__(self, *args, policy="all", **kwargs): + self.order = [] + # self._has_settings = False + self.lock = threading.Lock() + # self.update(*args, **kwargs) + + # # initialise the backend list + # self.BACKENDS.update(self.local()) + # self.BACKENDS.update(self.remote()) + # self.BACKENDS.update(self.mir()) + self.update() + + print("BACKENDS", self.BACKENDS) + + def update(self): + with self.lock: + self.BACKENDS.clear() + self.BACKENDS.update(self._local()) + self.BACKENDS.update(self._remote()) + self.BACKENDS.update(self._mir()) + + from earthkit.regrid.utils.config import CONFIG + + self.order = CONFIG.get("backend-order", []) or [] + + def backends(self, backend=None): + with self.lock: + return BACKEND_ORDER.select(self.order, backend, self.BACKENDS) + + def _local(self): + from earthkit.regrid.utils.config import CONFIG + + from .matrix import LocalMatrixBackend + + dirs = CONFIG.get("local-matrix-directories", []) + if dirs is None: + dirs = [] + if isinstance(dirs, str): + dirs = [dirs] + + r = {} + for d in dirs: + r[BackendKey("local-matrix", d)] = LocalMatrixBackend(d) + return r + + def _remote(self): + from earthkit.regrid.utils.config import CONFIG + + from .matrix import RemoteMatrixBackend + from .matrix import SystemRemoteMatrixBackend + + dirs = CONFIG.get("remote-matrix-directories", []) + if dirs is None: + dirs = [] + if isinstance(dirs, str): + dirs = [dirs] + + r = {} + for d in dirs: + r[BackendKey("remote-matrix", d)] = RemoteMatrixBackend(d) + + r[BackendKey("system-matrix", None)] = SystemRemoteMatrixBackend() + + return r + + def _mir(self): + from .mir import MirBackend + + return {BackendKey("mir", None): MirBackend()} + + +MANAGER = BackendManager() + +# def add_matrix_source(path): +# global DB_LIST +# for item in DB_LIST[1:]: +# if item.matrix_source() == path: +# return item +# db = MatrixDb.from_path(path) +# DB_LIST.append(db)s +# return db + + +# def find(*args, matrix_source=None, **kwargs): +# if matrix_source is None: +# return SYS_DB.find(*args, **kwargs) +# else: +# db = add_matrix_source(matrix_source) +# return db.find(*args, **kwargs) + + +# class Backend(metaclass=ABCMeta): +# @abstractmethod +# def interpolate(self, values, in_grid, out_grid, method, **kwargs): +# pass + + +# class BackendLoader: +# kind = "backend" + +# def load_module(self, module): +# return import_module(module, package=__name__).backend + +# def load_entry(self, entry): +# entry = entry.load() +# if callable(entry): +# return entry +# return entry.backend + +# def load_remote(self, name): +# return None + + +# class BackendMaker: +# BACKENDS = {} + +# def __init__(self): +# # Preregister the most important backends +# from .mir import MirBackend +# from .matrix import LocalMatrixBackend +# from .matrix import RemoteMatrixBackend + +# self.BACKENDS["mir"] = MirBackend +# self.BACKENDS["local"] = LocalMatrixBackend +# self.BACKENDS["remote"] = RemoteMatrixBackend + +# def __call__(self, name, *args, **kwargs): +# loader = BackendLoader() + +# if name in self.BACKENDS: +# klass = self.BACKENDS[name] +# else: +# klass = find_plugin(os.path.dirname(__file__), name, loader) +# self.BACKENDS[name] = klass + +# backend = klass(*args, **kwargs) + +# if getattr(backend, "name", None) is None: +# backend.name = name + +# return backend + +# def __getattr__(self, name): +# return self(name.replace("_", "-")) + + +# get_backend = BackendMaker() + +CONFIG.on_change(MANAGER.update) diff --git a/src/earthkit/regrid/db.py b/src/earthkit/regrid/backends/db.py similarity index 96% rename from src/earthkit/regrid/db.py rename to src/earthkit/regrid/backends/db.py index 0fc619e..1d66330 100644 --- a/src/earthkit/regrid/db.py +++ b/src/earthkit/regrid/backends/db.py @@ -542,6 +542,10 @@ def matrix_source(self): def from_path(path): return MatrixDb(LocalAccessor(path)) + @staticmethod + def from_url(url): + return MatrixDb(UrlAccessor(url)) + def __len__(self): return len(self.index) @@ -556,22 +560,22 @@ def _reset(self): SYS_DB = MatrixDb(SystemAccessor()) -DB_LIST = [SYS_DB] +# DB_LIST = [SYS_DB] -def add_matrix_source(path): - global DB_LIST - for item in DB_LIST[1:]: - if item.matrix_source() == path: - return item - db = MatrixDb.from_path(path) - DB_LIST.append(db) - return db +# def add_matrix_source(path): +# global DB_LIST +# for item in DB_LIST[1:]: +# if item.matrix_source() == path: +# return item +# db = MatrixDb.from_path(path) +# DB_LIST.append(db) +# return db -def find(*args, matrix_source=None, **kwargs): - if matrix_source is None: - return SYS_DB.find(*args, **kwargs) - else: - db = add_matrix_source(matrix_source) - return db.find(*args, **kwargs) +# def find(*args, matrix_source=None, **kwargs): +# if matrix_source is None: +# return SYS_DB.find(*args, **kwargs) +# else: +# db = add_matrix_source(matrix_source) +# return db.find(*args, **kwargs) diff --git a/src/earthkit/regrid/backends/matrix.py b/src/earthkit/regrid/backends/matrix.py new file mode 100644 index 0000000..3fcfef3 --- /dev/null +++ b/src/earthkit/regrid/backends/matrix.py @@ -0,0 +1,65 @@ +# (C) Copyright 2023 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +from abc import abstractmethod +from functools import cached_property + +from . import Backend + + +class MatrixBackend(Backend): + def __init__(self, path_or_url=None): + self.path_or_url = path_or_url + + def interpolate(self, values, in_grid, out_grid, method, **kwargs): + z, shape = self.db.find(in_grid, out_grid, method, **kwargs) + + if z is None: + raise ValueError(f"No matrix found! {in_grid=} {out_grid=} {method=}") + + # This should check for 1D (GG) and 2D (LL) matrices + values = values.reshape(-1, 1) + + print("values.shape", values.shape) + print("z.shape", z.shape) + + values = z @ values + + print("values.shape", values.shape) + + return values.reshape(shape) + + @property + @abstractmethod + def db(self): + pass + + +class LocalMatrixBackend(MatrixBackend): + @cached_property + def db(self): + from .db import MatrixDb + + return MatrixDb.from_path(self.path_or_url) + + +class RemoteMatrixBackend(MatrixBackend): + @cached_property + def db(self): + from .db import MatrixDb + + return MatrixDb.from_url(self.path_or_url) + + +class SystemRemoteMatrixBackend(RemoteMatrixBackend): + @cached_property + def db(self): + from .db import SYS_DB + + return SYS_DB diff --git a/src/earthkit/regrid/backends/mir.py b/src/earthkit/regrid/backends/mir.py new file mode 100644 index 0000000..fb2f70c --- /dev/null +++ b/src/earthkit/regrid/backends/mir.py @@ -0,0 +1,20 @@ +# (C) Copyright 2023 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +from . import Backend + + +class MirBackend(Backend): + def interpolate(self, values, in_grid, out_grid, method, **kwargs): + try: + import mir + except ImportError: + raise ImportError("The 'mir' package is required for this operation") + + return mir.interpolate(values, in_grid, out_grid, method, **kwargs) diff --git a/src/earthkit/regrid/interpolate.py b/src/earthkit/regrid/interpolate.py index fbe4d46..44b24be 100644 --- a/src/earthkit/regrid/interpolate.py +++ b/src/earthkit/regrid/interpolate.py @@ -7,15 +7,17 @@ # nor does it submit to any jurisdiction. # -from earthkit.regrid.db import find +import logging +LOG = logging.getLogger(__name__) -def interpolate(values, in_grid=None, out_grid=None, method="linear", **kwargs): + +def interpolate(values, in_grid=None, out_grid=None, method="linear", backend=None, **kwargs): interpolator = _find_interpolator(values) if interpolator is None: raise ValueError(f"Cannot interpolate data with type={type(values)}") - return interpolator(values, in_grid=in_grid, out_grid=out_grid, method=method, **kwargs) + return interpolator(values, in_grid=in_grid, out_grid=out_grid, method=method, backend=backend, **kwargs) def _find_interpolator(values): @@ -25,21 +27,34 @@ def _find_interpolator(values): return None -def _interpolate(values, in_grid, out_grid, method, **kwargs): - z, shape = find(in_grid, out_grid, method, **kwargs) +class Interpolator: + @staticmethod + def _interpolate(values, in_grid, out_grid, **kwargs): + from earthkit.regrid.backends import MANAGER - if z is None: - raise ValueError(f"No matrix found! {in_grid=} {out_grid=} {method=}") + method = kwargs.pop("method") + backend = kwargs.pop("backend") + backends = MANAGER.backends(backend) - # This should check for 1D (GG) and 2D (LL) matrices - values = values.reshape(-1, 1) + if not backends: + raise ValueError(f"No backend found for {backend}") - values = z @ values + if len(backends) == 1: + return backends[0].interpolate(values, in_grid, out_grid, method, **kwargs) + else: + errors = [] + for b in backends: + LOG.debug(f"Trying backend {b}") + print(f"Trying backend {b}") + try: + return b.interpolate(values, in_grid, out_grid, method, **kwargs) + except Exception as e: + errors.append(e) - return values.reshape(shape) + raise ValueError("No backend could interpolate the data", errors) -class NumpyInterpolator: +class NumpyInterpolator(Interpolator): @staticmethod def match(values): import numpy as np @@ -49,18 +64,22 @@ def match(values): def __call__(self, values, **kwargs): in_grid = kwargs.pop("in_grid") out_grid = kwargs.pop("out_grid") - method = kwargs.pop("method") - return _interpolate(values, in_grid, out_grid, method, **kwargs) + return self._interpolate(values, in_grid, out_grid, **kwargs) -class FieldListInterpolator: +class FieldListInterpolator(Interpolator): @staticmethod def match(values): + from earthkit.regrid.utils import is_module_loaded + + if not is_module_loaded("earthkit.data"): + return False + try: import earthkit.data return isinstance(values, earthkit.data.FieldList) - except ImportError: + except Exception: return False def __call__(self, values, **kwargs): @@ -71,16 +90,15 @@ def __call__(self, values, **kwargs): # if in_grid is not None: # raise ValueError(f"in_grid {in_grid} cannot be used for FieldList interpolation") out_grid = kwargs.pop("out_grid") - method = kwargs.pop("method") r = earthkit.data.FieldList() for f in ds: vv = f.to_numpy(flatten=True) - v_res = _interpolate( + v_res = self._interpolate( vv, f.metadata().gridspec if in_grid is None else in_grid, out_grid, - method, + # method, **kwargs, ) md_res = f.metadata().override(gridspec=out_grid) diff --git a/src/earthkit/regrid/utils/__init__.py b/src/earthkit/regrid/utils/__init__.py index 2d46ff5..825f8c8 100644 --- a/src/earthkit/regrid/utils/__init__.py +++ b/src/earthkit/regrid/utils/__init__.py @@ -6,6 +6,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. # +import sys try: # There is a bug in tqdm that expects ipywidgets @@ -59,3 +60,7 @@ def no_progress_bar(total, initial=0, desc=None): leave=False, desc=desc, ) + + +def is_module_loaded(module_name): + return module_name in sys.modules diff --git a/src/earthkit/regrid/utils/builder.py b/src/earthkit/regrid/utils/builder.py new file mode 100644 index 0000000..24dc476 --- /dev/null +++ b/src/earthkit/regrid/utils/builder.py @@ -0,0 +1,152 @@ +# (C) Copyright 2023 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import hashlib +import json +import os + +from scipy.sparse import load_npz + +from earthkit.regrid.backends.db import VERSION +from earthkit.regrid.backends.db import MatrixIndex + +from .matrix import matrix_memory_size +from .mir import mir_cached_matrix_to_file + + +def regular_ll(entry): + bb = entry["bbox"] + d = { + "grid": [entry["increments"][x] for x in ("west_east", "south_north")], + "shape": [entry["nj"], entry["ni"]], + "area": [bb["north"], bb["west"], bb["south"], bb["east"]], + } + if "global" in entry: + d["global"] = entry["global"] + return d + + +def reduced_gg(entry): + pl = entry["pl"] + G = "O" if pl[1] - pl[0] == 4 else "N" + N = entry["N"] + bb = entry["bbox"] + + d = { + "grid": f"{G}{N}", + "shape": [sum(pl)], + "area": [bb["north"], bb["west"], bb["south"], bb["east"]], + } + + if "global" in entry: + d["global"] = entry["global"] + + return d + + +def healpix(entry): + d = {"grid": entry["grid"], "ordering": entry["ordering"]} + return d + + +def orca(entry): + d = {"grid": entry["unstructuredGridType"] + "_" + entry["unstructuredGridSubtype"]} + return d + + +def make_sha(d): + m = hashlib.sha256() + m.update(json.dumps(d, sort_keys=True).encode("utf-8")) + return m.hexdigest() + + +def get_method_name(entry): + method = entry["interpolation"]["method"] + if isinstance(method, dict): + method = method["type"] + return method + + +def make_matrix(input_path, output_path, index_file=None, global_input=None, global_output=None): + + with open(input_path) as f: + entry = json.load(f) + + inter_ori = dict(entry["interpolation"]) + method_name = MatrixIndex.interpolation_method_name(entry) + + uid = MatrixIndex.make_interpolation_uid(entry) + if uid != method_name: + entry["interpolation"]["_uid"] = uid + else: + entry["interpolation"]["method"] = method_name + + # create output folder + matrix_output_path = os.path.join( + output_path, + MatrixIndex.matrix_dir_name(entry), + # f"{inter_engine}_{inter_version}_{inter_uid}" + ) + os.makedirs(matrix_output_path, exist_ok=True) + + # create matrix + cache_file = entry["matrix"].pop("cache_file") + m = {} + m["input"] = entry["input"] + m["output"] = entry["output"] + m["interpolation"] = inter_ori + + key = make_sha(m) + name = key + + print(f"entry={entry}") + npz_file = os.path.join(matrix_output_path, f"{name}.npz") + mir_cached_matrix_to_file(cache_file, npz_file) + + if index_file is None: + index_file = os.path.join(output_path, "index.json") + + if os.path.exists(index_file): + with open(index_file) as f: + index = json.load(f) + if index.get("version", None) != VERSION: + raise ValueError(f"{index_file=} version must be {VERSION}") + else: + index = {} + index["version"] = VERSION + index["matrix"] = {} + + def convert(x): + proc = globals()[x["type"]] + return proc(x) + + if global_input is not None and "global" not in entry["input"]: + entry["input"]["global"] = 1 if global_input else 0 + + if global_output is not None and "global" not in entry["output"]: + entry["output"]["global"] = 1 if global_output else 0 + + # get matrix size + z = load_npz(npz_file) + mem_size = matrix_memory_size(z) + z = None + + index["matrix"][key] = dict( + input=convert(entry["input"]), + output=convert(entry["output"]), + interpolation=entry["interpolation"], + nnz=entry["matrix"]["nnz"], + memory=mem_size, + ) + + with open(index_file, "w") as f: + json.dump(index, f, indent=4) + + print("Written", npz_file) + print("Written", index_file) diff --git a/src/earthkit/regrid/utils/config.py b/src/earthkit/regrid/utils/config.py index 18d87c3..d6fab84 100644 --- a/src/earthkit/regrid/utils/config.py +++ b/src/earthkit/regrid/utils/config.py @@ -44,27 +44,42 @@ def check(self, value): def explain(self): pass - @staticmethod - def make(value): - v = _validators.get(type(value), None) - if v is not None: - return v(value) - else: - raise TypeError(f"Cannot create Validator for type={type(value)}") + # @staticmethod + # def make(value): + # v = _validators.get(type(value), None) + # if v is not None: + # return v(value) + # else: + # raise TypeError(f"Cannot create Validator for type={type(value)}") + + +# class ValueValidator(Validator): +# def __init__(self, value): +# self.value = value +# def check(self, value): +# return value == self.value -class ValueValidator(Validator): - def __init__(self, value): - self.value = value +# def explain(self): +# return f"Valid when = {self.value}." + + +class ValuesValidator(Validator): + """Check if value is in a list of valid values""" + + def __init__(self, values): + self.values = values def check(self, value): - return value == self.value + return value in self.values def explain(self): - return f"Valid when = {self.value}." + return f"Valid values: {list_to_human(self.values)}." class ListValidator(Validator): + """Check if a list of values is in a list of valid values""" + def __init__(self, values): self.values = values @@ -75,7 +90,7 @@ def explain(self): return f"Valid values: {list_to_human(self.values)}." -_validators = {bool: ValueValidator, list: ListValidator} +# _validators = {bool: ValueValidator, list: ValuesValidator} class ConfigOption: @@ -153,7 +168,7 @@ def validate(self, name, value): "user", """Caching policy. {validator} See :doc:`/guide/caching` for more information. """, - validator=ListValidator(["off", "temporary", "user"]), + validator=ValuesValidator(["off", "temporary", "user"]), ), "maximum-cache-size": _( "5GB", @@ -194,13 +209,40 @@ def validate(self, name, value): "largest", """The matrix in-memory cache policy. {validator} See :ref:`memory_cache` for more information.""", - validator=ListValidator(["off", "unlimited", "largest", "lru"]), + validator=ValuesValidator(["off", "unlimited", "largest", "lru"]), ), "matrix-memory-cache-strict-mode": _( False, """Raise exception if the matrix cannot be fitted into the in-memory cache. Only used when ``matrix-memory-cache-policy`` is ``"largest"`` or ``"lru"``. See :ref:`memory_cache` for more information.""", ), + "local-matrix-directories": _( + None, + """Parent of the cache directory when ``cache-policy`` is ``temporary``. + See :doc:`/guide/caching` for more information.""", + getter="_as_str", + none_ok=True, + ), + "remote-matrix-directories": _( + None, + """Parent of the cache directory when ``cache-policy`` is ``temporary``. + See :doc:`/guide/caching` for more information.""", + getter="_as_str", + none_ok=True, + ), + # "backend-policy": _( + # "all", + # """The interpolation backend policy. {validator} + # See :ref:`backend` for more information.""", + # validator=ListValidator(["local_matrix", "matrix", "mir", "all"]), + # ), + "backend-order": _( + None, + """The interpolation backend order. {validator} + See :ref:`backend` for more information.""", + getter="_as_list", + none_ok=True, + ), } @@ -536,6 +578,14 @@ def _as_int(self, name, value, none_ok): value = value.replace('"', "").replace("'", "").strip() return int(value) + def _as_list(self, name, value, none_ok): + if value is None and none_ok: + return [] + if isinstance(value, str): + value = value.replace('"', "").replace("'", "").strip() + return value.split(",") + return list(value) + @forward def temporary(self, *args, config_yaml=None, **kwargs): tmp = Config(config_yaml, self._config) diff --git a/src/earthkit/regrid/utils/matrix.py b/src/earthkit/regrid/utils/matrix.py index ace1865..03a1d71 100644 --- a/src/earthkit/regrid/utils/matrix.py +++ b/src/earthkit/regrid/utils/matrix.py @@ -7,70 +7,6 @@ # nor does it submit to any jurisdiction. # -import hashlib -import json -import os - -from scipy.sparse import load_npz - -from earthkit.regrid.db import VERSION -from earthkit.regrid.db import MatrixIndex - -from .mir import mir_cached_matrix_to_file - - -def regular_ll(entry): - bb = entry["bbox"] - d = { - "grid": [entry["increments"][x] for x in ("west_east", "south_north")], - "shape": [entry["nj"], entry["ni"]], - "area": [bb["north"], bb["west"], bb["south"], bb["east"]], - } - if "global" in entry: - d["global"] = entry["global"] - return d - - -def reduced_gg(entry): - pl = entry["pl"] - G = "O" if pl[1] - pl[0] == 4 else "N" - N = entry["N"] - bb = entry["bbox"] - - d = { - "grid": f"{G}{N}", - "shape": [sum(pl)], - "area": [bb["north"], bb["west"], bb["south"], bb["east"]], - } - - if "global" in entry: - d["global"] = entry["global"] - - return d - - -def healpix(entry): - d = {"grid": entry["grid"], "ordering": entry["ordering"]} - return d - - -def orca(entry): - d = {"grid": entry["unstructuredGridType"] + "_" + entry["unstructuredGridSubtype"]} - return d - - -def make_sha(d): - m = hashlib.sha256() - m.update(json.dumps(d, sort_keys=True).encode("utf-8")) - return m.hexdigest() - - -def get_method_name(entry): - method = entry["interpolation"]["method"] - if isinstance(method, dict): - method = method["type"] - return method - def matrix_memory_size(m): # see: https://stackoverflow.com/questions/11173019/determining-the-byte-size-of-a-scipy-sparse-matrix @@ -81,81 +17,3 @@ def matrix_memory_size(m): except Exception as e: print(e) return 0 - - -def make_matrix(input_path, output_path, index_file=None, global_input=None, global_output=None): - with open(input_path) as f: - entry = json.load(f) - - inter_ori = dict(entry["interpolation"]) - method_name = MatrixIndex.interpolation_method_name(entry) - - uid = MatrixIndex.make_interpolation_uid(entry) - if uid != method_name: - entry["interpolation"]["_uid"] = uid - else: - entry["interpolation"]["method"] = method_name - - # create output folder - matrix_output_path = os.path.join( - output_path, - MatrixIndex.matrix_dir_name(entry), - # f"{inter_engine}_{inter_version}_{inter_uid}" - ) - os.makedirs(matrix_output_path, exist_ok=True) - - # create matrix - cache_file = entry["matrix"].pop("cache_file") - m = {} - m["input"] = entry["input"] - m["output"] = entry["output"] - m["interpolation"] = inter_ori - - key = make_sha(m) - name = key - - print(f"entry={entry}") - npz_file = os.path.join(matrix_output_path, f"{name}.npz") - mir_cached_matrix_to_file(cache_file, npz_file) - - if index_file is None: - index_file = os.path.join(output_path, "index.json") - - if os.path.exists(index_file): - with open(index_file) as f: - index = json.load(f) - if index.get("version", None) != VERSION: - raise ValueError(f"{index_file=} version must be {VERSION}") - else: - index = {} - index["version"] = VERSION - index["matrix"] = {} - - def convert(x): - proc = globals()[x["type"]] - return proc(x) - - if global_input is not None and "global" not in entry["input"]: - entry["input"]["global"] = 1 if global_input else 0 - - if global_output is not None and "global" not in entry["output"]: - entry["output"]["global"] = 1 if global_output else 0 - - # get matrix size - z = load_npz(npz_file) - mem_size = matrix_memory_size(z) - z = None - - index["matrix"][key] = dict( - input=convert(entry["input"]), - output=convert(entry["output"]), - interpolation=entry["interpolation"], - nnz=entry["matrix"]["nnz"], - memory=mem_size, - ) - - with open(index_file, "w") as f: - json.dump(index, f, indent=4) - - print("Written", npz_file) - print("Written", index_file) diff --git a/src/earthkit/regrid/utils/memcache.py b/src/earthkit/regrid/utils/memcache.py index b01cc4c..373cc7c 100644 --- a/src/earthkit/regrid/utils/memcache.py +++ b/src/earthkit/regrid/utils/memcache.py @@ -26,6 +26,8 @@ def matrix_size(m): + from earthkit.regrid.utils.builder import matrix_memory_size + m = m[0] try: return matrix_memory_size(m) diff --git a/src/earthkit/regrid/utils/plugins.py b/src/earthkit/regrid/utils/plugins.py new file mode 100644 index 0000000..f4af6a1 --- /dev/null +++ b/src/earthkit/regrid/utils/plugins.py @@ -0,0 +1,200 @@ +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +"""Plugins rely on the python plugins_ system using ``entry_points``. + +.. _plugins: https://packaging.python.org/guides/creating-and-discovering-plugins/ + +""" + +import logging +import os +import sys +from collections import defaultdict +from importlib import import_module +from typing import List +from typing import Union + +import entrypoints + +import earthkit.regrid +from earthkit.regrid.utils.humanize import did_you_mean + +LOG = logging.getLogger(__name__) + +CACHE = {} + +PLUGINS = {} + +REGISTERED = defaultdict(dict) + +AVAILABLE_KINDS = ["backend"] + + +def refresh(kind=None): + if kind in PLUGINS: + PLUGINS.pop(kind) + return + if kind is None: + PLUGINS.clear() + return + assert kind in AVAILABLE_KINDS, (kind, AVAILABLE_KINDS) + + +def _load_plugins(kind): + plugins = {} + for e in entrypoints.get_group_all(f"earthkit.regrid.{kind}s"): + plugins[e.name.replace("_", "-")] = e + return plugins + + +def load_plugins(kind): + """Loads the plugins for a given kind. The plugin needs to have registered itself with entry_point. + + Parameters + ---------- + kind : str + Plugin type such as "backend". + """ + if PLUGINS.get(kind) is None: + PLUGINS[kind] = _load_plugins(kind) + return PLUGINS[kind] + + +def find_plugin(directories: Union[str, List[str]], name: str, loader, refreshed=False): + """Find a plugin by name . + + Parameters + ---------- + directories : list or str + List of directories to be searched to find the plugin. + name : str + Name of the plugin + loader : class + Class implementing load_yaml() and load_module() + + Returns + ------- + Return what the loader will returns when applied to the plugin with the right name + `name`, found in one of the directories of the `directories` list. + + Raises + ------ + NameError + If plugin is not found. + """ + # noqa: E501 + candidates = set() + + if name in REGISTERED[loader.kind]: + return getattr(REGISTERED[loader.kind][name], loader.kind) + + candidates.update(REGISTERED[loader.kind].keys()) + + plugins = load_plugins(loader.kind) + + if name in plugins: + plugin = plugins[name] + return loader.load_entry(plugin) + + candidates.update(plugins.keys()) + + if not isinstance(directories, (tuple, list)): + directories = [directories] + + for directory in directories: + n = len(directory) + for path, _, files in os.walk(directory): + path = path[n:] + for f in files: + base, ext = os.path.splitext(f) + if ext == ".yaml": + candidates.add(base) + if base == name: + full = os.path.join(directory, path, f) + return loader.load_yaml(full) + + if ext == ".py" and base[0] != "_": + full = os.path.join(path, base) + + if sys.platform == "win32": + full = full.replace("\\", "/") + + if full[0] != "/": + full = "/" + full + + p = full[1:].replace("/", "-").replace("_", "-") + candidates.add(p) + if p == name: + return loader.load_module(full.replace("/", ".")) + + if not refreshed: + LOG.debug("Cannot find {loader.kind} 'name'. Refreshing plugin list.") + refresh(loader.kind) + return find_plugin(directories, name, loader, refreshed=True) + + module = loader.load_remote(name) + if module is not None: + return module + + correction = did_you_mean(name, candidates) + + if correction is not None: + LOG.warning( + "Cannot find %s '%s', did you mean '%s'?", + loader.kind, + name, + correction, + ) + + candidates = ", ".join(sorted(c for c in candidates if "-" in c)) + + raise NameError(f"Cannot find {loader.kind} '{name}' (values are: {candidates})") + + +def directories(owner: bool = False) -> list: + """Return a list of directories that are used in the project . + + If owner = False, return a list of directories where to search for plugins. + + If owner = True, return a list of 2-uples to include the owner in the return value. + + Parameters + ---------- + owner : bool, optional + + """ + result = [] + for kind in ["backend"]: + for name, v in load_plugins(kind).items(): + try: + module = import_module(v.module_name) + result.append((name, os.path.dirname(module.__file__))) + except Exception: + LOG.error("Cannot load module %s", v.module_name, exc_info=True) + + result.append(("earthkit.regrid", os.path.dirname(earthkit.regrid.__file__))) + + if owner: + return result + + return [x[1] for x in result] + + +class RegisteredPlugin: + pass + + +def register(kind, name, proc): + assert name not in REGISTERED[kind], (kind, name, REGISTERED) + if not hasattr(proc, kind): + o = RegisteredPlugin() + setattr(o, kind, proc) + proc = o + REGISTERED[kind][name] = proc diff --git a/tests/conftest.py b/tests/conftest.py index 1425b31..1be162e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,7 +12,7 @@ def pytest_runtest_setup(item): marks_in_items = list([m.name for m in item.iter_markers()]) - from earthkit.regrid.db import SYS_DB + from earthkit.regrid.backends.db import SYS_DB SYS_DB._clear_index() diff --git a/tests/test_gridspec.py b/tests/test_gridspec.py index de16d65..ee2a77c 100644 --- a/tests/test_gridspec.py +++ b/tests/test_gridspec.py @@ -8,7 +8,7 @@ import pytest -from earthkit.regrid.db import SYS_DB +from earthkit.regrid.backends.db import SYS_DB @pytest.mark.parametrize( diff --git a/tests/test_local.py b/tests/test_local.py index fe6babc..5829d4f 100644 --- a/tests/test_local.py +++ b/tests/test_local.py @@ -12,7 +12,6 @@ import pytest from earthkit.regrid import interpolate -from earthkit.regrid.db import add_matrix_source DB_PATH = os.path.join(os.path.dirname(__file__), "data", "local", "db") DATA_PATH = os.path.join(os.path.dirname(__file__), "data", "local") @@ -23,20 +22,33 @@ def file_in_testdir(filename): return os.path.join(DATA_PATH, filename) -def run_interpolate(mode): - v_in = np.load(file_in_testdir("in_N32.npz"))["arr_0"] - np.load(file_in_testdir(f"out_N32_10x10_{mode}.npz"))["arr_0"] - interpolate( - v_in, - {"grid": "N32"}, - {"grid": [10, 10]}, - matrix_source=DB_PATH, - method=mode, - ) +def get_local_db(): + from earthkit.regrid.backends.db import MatrixDb + + return MatrixDb.from_path(DB_PATH) + + +def run_interpolate(v_in, in_grid, out_grid, method): + from earthkit.regrid import config + + with config.temporary(local_matrix_directories=DB_PATH, backend_order=["local-matrix"]): + # v_in = np.load(file_in_testdir("in_N32.npz"))["arr_0"] + # np.load(file_in_testdir(f"out_N32_10x10_{method}.npz"))["arr_0"] + + # in_grid = in_grid or {"grid": "N32"} + # out_grid = out_grid or {"grid": [10, 10]} + + return interpolate( + v_in, + in_grid, + out_grid, + # matrix_source=DB_PATH, + method=method, + ) def test_local_index(): - DB = add_matrix_source(DB_PATH) + DB = get_local_db() # we have an extra unsupported entry in the index file. We have # to be sure the DB is loaded correctly bypassing the unsupported # entry. @@ -71,7 +83,7 @@ def test_local_index(): def test_local_ll_to_ll(method): v_in = np.load(file_in_testdir("in_5x5.npz"))["arr_0"] v_ref = np.load(file_in_testdir(f"out_5x5_10x10_{method}.npz"))["arr_0"] - v_res = interpolate(v_in, {"grid": [5, 5]}, {"grid": [10, 10]}, matrix_source=DB_PATH, method=method) + v_res = run_interpolate(v_in, in_grid={"grid": [5, 5]}, out_grid={"grid": [10, 10]}, method=method) assert v_res.shape == (19, 36) assert np.allclose(v_res.flatten(), v_ref) @@ -81,7 +93,7 @@ def test_local_ll_to_ll(method): def test_local_ogg_to_ll(method): v_in = np.load(file_in_testdir("in_O32.npz"))["arr_0"] v_ref = np.load(file_in_testdir(f"out_O32_10x10_{method}.npz"))["arr_0"] - v_res = interpolate(v_in, {"grid": "O32"}, {"grid": [10, 10]}, matrix_source=DB_PATH, method=method) + v_res = run_interpolate(v_in, in_grid={"grid": "O32"}, out_grid={"grid": [10, 10]}, method=method) assert v_res.shape == (19, 36) assert np.allclose(v_res.flatten(), v_ref) @@ -91,11 +103,10 @@ def test_local_ogg_to_ll(method): def test_local_ngg_to_ll(method): v_in = np.load(file_in_testdir("in_N32.npz"))["arr_0"] v_ref = np.load(file_in_testdir(f"out_N32_10x10_{method}.npz"))["arr_0"] - v_res = interpolate( + v_res = run_interpolate( v_in, - {"grid": "N32"}, - {"grid": [10, 10]}, - matrix_source=DB_PATH, + in_grid={"grid": "N32"}, + out_grid={"grid": [10, 10]}, method=method, ) @@ -107,11 +118,10 @@ def test_local_ngg_to_ll(method): def test_local_healpix_ring_to_ll(method): v_in = np.load(file_in_testdir("in_H4_ring.npz"))["arr_0"] v_ref = np.load(file_in_testdir(f"out_H4_ring_10x10_{method}.npz"))["arr_0"] - v_res = interpolate( + v_res = run_interpolate( v_in, - {"grid": "H4", "ordering": "ring"}, - {"grid": [10, 10]}, - matrix_source=DB_PATH, + in_grid={"grid": "H4", "ordering": "ring"}, + out_grid={"grid": [10, 10]}, method=method, ) @@ -123,11 +133,10 @@ def test_local_healpix_ring_to_ll(method): def test_local_healpix_nested_to_ll(method): v_in = np.load(file_in_testdir("in_H4_nested.npz"))["arr_0"] v_ref = np.load(file_in_testdir(f"out_H4_nested_10x10_{method}.npz"))["arr_0"] - v_res = interpolate( + v_res = run_interpolate( v_in, - {"grid": "H4", "ordering": "nested"}, - {"grid": [10, 10]}, - matrix_source=DB_PATH, + in_grid={"grid": "H4", "ordering": "nested"}, + out_grid={"grid": [10, 10]}, method=method, ) @@ -201,7 +210,7 @@ def test_local_healpix_nested_to_ll(method): ], ) def test_local_gridspec_ok(gs_in, gs_out): - DB = add_matrix_source(DB_PATH) + DB = get_local_db() r = DB.find_entry(gs_in, gs_out, "linear") assert r, f"gs_in={gs_in} gs_out={gs_out}" @@ -241,7 +250,7 @@ def test_local_gridspec_ok(gs_in, gs_out): ], ) def test_local_gridspec_bad(gs_in, gs_out, err): - DB = add_matrix_source(DB_PATH) + DB = get_local_db() if err: with pytest.raises(err): r = DB.find_entry(gs_in, gs_out, "linear") diff --git a/tests/test_memcache.py b/tests/test_memcache.py index 056083f..33c6145 100644 --- a/tests/test_memcache.py +++ b/tests/test_memcache.py @@ -25,20 +25,23 @@ def file_in_testdir(filename): def run_interpolate(mode): - v_in = np.load(file_in_testdir("in_N32.npz"))["arr_0"] - np.load(file_in_testdir(f"out_N32_10x10_{mode}.npz"))["arr_0"] - interpolate( - v_in, - {"grid": "N32"}, - {"grid": [10, 10]}, - matrix_source=DB_PATH, - method=mode, - ) + from earthkit.regrid import config + + with config.temporary(local_matrix_directories=DB_PATH, backend_order=["local-matrix"]): + + v_in = np.load(file_in_testdir("in_N32.npz"))["arr_0"] + np.load(file_in_testdir(f"out_N32_10x10_{mode}.npz"))["arr_0"] + interpolate( + v_in, + {"grid": "N32"}, + {"grid": [10, 10]}, + method=mode, + ) @pytest.fixture def patch_estimate_matrix_memory(monkeypatch): - from earthkit.regrid.db import MatrixIndex + from earthkit.regrid.backends.db import MatrixIndex def patched_estimate_memory(self): return 200000 diff --git a/tests/test_remote_index.py b/tests/test_remote_index.py index 6723b0d..b4cb1c4 100644 --- a/tests/test_remote_index.py +++ b/tests/test_remote_index.py @@ -18,7 +18,7 @@ @pytest.mark.download @pytest.mark.tmp_cache def test_remote_index_handling(): - from earthkit.regrid.db import SYS_DB + from earthkit.regrid.backends.db import SYS_DB method = "linear" diff --git a/tools/utils/matrix.py b/tools/utils/matrix.py index a914e52..9a11bff 100644 --- a/tools/utils/matrix.py +++ b/tools/utils/matrix.py @@ -118,7 +118,7 @@ def create_matrix_files( if add_to_index: # process matrix and add it to index json file - from earthkit.regrid.utils.matrix import make_matrix + from earthkit.regrid.utils.builder import make_matrix make_matrix( matrix_json,