Skip to content
This repository has been archived by the owner on May 2, 2024. It is now read-only.

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
cfkanesan committed Nov 17, 2023
1 parent 7c44e00 commit 2af735b
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 27 deletions.
19 changes: 9 additions & 10 deletions src/idpi/data_source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Data source helper class."""

# Standard library
import dataclasses as dc
import sys
Expand Down Expand Up @@ -55,25 +57,22 @@ def query(self, request):
raise NotImplementedError(f"request of type {type(request)} not supported.")

@query.register
def _(self, request: mars.Request):
def _(self, request: dict):
# The presence of the yield keyword makes this def a generator.
# As a result, the context manager will remain active until the
# exhaustion of the data source iterator.
grib_def = config.get("data_scope", GRIB_DEF[request.model])
req_kwargs = self.request_template | request
req = mars.Request(**req_kwargs)

grib_def = config.get("data_scope", GRIB_DEF[req.model])
with grib_def_ctx(grib_def):
if self.datafiles:
fs = ekd.from_source("file", self.datafiles)
source = fs.sel(request.dump(exclude_defaults=True))
source = fs.sel(req_kwargs)
else:
source = ekd.from_source("fdb", request.to_fdb())
source = ekd.from_source("fdb", req.to_fdb())
yield from source

@query.register
def _(self, request: dict):
req_kwargs = self.request_template | request
req = mars.Request(**req_kwargs)
yield from self.query(req)

@query.register
def _(self, request: str):
yield from self.query({"param": request})
Expand Down
13 changes: 9 additions & 4 deletions src/idpi/mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from pydantic import dataclasses as pdc


ValidationError = pydantic.ValidationError


class Class(str, Enum):
OPERATIONAL_DATA = "od"

Expand Down Expand Up @@ -71,7 +74,7 @@ class Request:

expver: str = "0001"
levelist: int | tuple[int, ...] | None = None
number: int | tuple[int, ...] = 1
number: int | tuple[int, ...] = 0
step: int | tuple[int, ...] = 0

class_: Class = dc.field(
Expand All @@ -83,20 +86,22 @@ class Request:
stream: Stream = Stream.ENS_FORECAST
type: Type = Type.ENS_MEMBER

def dump(self, exclude_defaults: bool = False):
def dump(self):
root = pydantic.RootModel(self)
return root.model_dump(
mode="json",
by_alias=True,
exclude_none=True,
exclude_defaults=exclude_defaults,
)

def to_fdb(self):
mapping = _load_mapping()
param_id = mapping[self.param]["cosmo"]["paramId"]
staggered = mapping[self.param]["cosmo"].get("vertStag", False)

if self.date is None or self.time is None:
raise RuntimeError("date and time are required fields for FDB.")

if self.levelist is None and self.levtype == LevType.MODEL_LEVEL:
n_lvl = N_LVL[self.model]
if staggered:
Expand All @@ -106,4 +111,4 @@ def to_fdb(self):
levelist = self.levelist

obj = dc.replace(self, param=param_id, levelist=levelist)
return obj.dump(exclude_defaults=False)
return obj.dump()
59 changes: 54 additions & 5 deletions tests/test_idpi/test_data_source.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,74 @@
from contextlib import nullcontext
from unittest.mock import patch, call

from idpi import data_source, mars
import pytest

from idpi import config, data_source, mars

@patch.object(data_source.ekd, "from_source")
def test_query_files(mock_from_source):

@pytest.fixture
def mock_from_source():
with patch.object(data_source.ekd, "from_source") as mock:
yield mock


@pytest.fixture
def mock_grib_def_ctx():
with patch.object(data_source, "grib_def_ctx") as mock:
mock.return_value = nullcontext()
yield mock


def test_query_files(mock_from_source, mock_grib_def_ctx):
datafiles = ["foo"]
param = "bar"

ds = data_source.DataSource(datafiles)
for _ in ds.query(param):
pass

assert mock_grib_def_ctx.mock_calls == [call("cosmo")]
assert mock_from_source.mock_calls == [
call("file", datafiles),
call().sel({"param": param}),
call().sel().__iter__(),
]


def test_query_files_tuple(mock_from_source, mock_grib_def_ctx):
datafiles = ["foo"]
request = param, levtype = ("bar", "ml")

ds = data_source.DataSource(datafiles)
for _ in ds.query(request):
pass

assert mock_grib_def_ctx.mock_calls == [call("cosmo")]
assert mock_from_source.mock_calls == [
call("file", datafiles),
call().sel({"param": param, "levtype": levtype}),
call().sel().__iter__(),
]


def test_query_files_ifs(mock_from_source, mock_grib_def_ctx):
datafiles = ["foo"]
param = "bar"

with config.set_values(data_scope="ifs"):
ds = data_source.DataSource(datafiles)
for _ in ds.query(param):
pass

assert mock_grib_def_ctx.mock_calls == [call("ifs")]
assert mock_from_source.mock_calls == [
call("file", datafiles),
call().sel({"param": param}),
call().sel().__iter__(),
]


@patch.object(data_source.ekd, "from_source")
def test_query_fdb(mock_from_source):
def test_query_fdb(mock_from_source, mock_grib_def_ctx):
datafiles = []
param = "U"
template = {"date": "20200101", "time": "0000"}
Expand All @@ -29,6 +77,7 @@ def test_query_fdb(mock_from_source):
for _ in ds.query(param):
pass

assert mock_grib_def_ctx.mock_calls == [call("cosmo")]
assert mock_from_source.mock_calls == [
call("fdb", mars.Request(param, **template).to_fdb()),
call().__iter__(),
Expand Down
9 changes: 1 addition & 8 deletions tests/test_idpi/test_mars.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def sample():
"levtype": "ml",
"levelist": list(range(1, 81)),
"model": "COSMO-1E",
"number": 1,
"number": 0,
"stream": "enfo",
"param": 500028, # U
"time": "0000",
Expand Down Expand Up @@ -59,10 +59,3 @@ def test_fdb_sfc(sample):
def test_request_raises():
with pytest.raises(ValueError):
mars.Request("U", date="20200101", time="0000", model="undef")


def test_no_defaults():
observed = mars.Request("U").dump(exclude_defaults=True)
expected = {"param": "U"}

assert observed == expected

0 comments on commit 2af735b

Please sign in to comment.