From 2af735baf3d9493dac52288f1e01584a3fb8e41f Mon Sep 17 00:00:00 2001 From: Christian Kanesan Date: Fri, 17 Nov 2023 16:46:41 +0100 Subject: [PATCH] fix --- src/idpi/data_source.py | 19 +++++----- src/idpi/mars.py | 13 +++++-- tests/test_idpi/test_data_source.py | 59 ++++++++++++++++++++++++++--- tests/test_idpi/test_mars.py | 9 +---- 4 files changed, 73 insertions(+), 27 deletions(-) diff --git a/src/idpi/data_source.py b/src/idpi/data_source.py index 51914608..e9a37414 100644 --- a/src/idpi/data_source.py +++ b/src/idpi/data_source.py @@ -1,3 +1,5 @@ +"""Data source helper class.""" + # Standard library import dataclasses as dc import sys @@ -55,25 +57,22 @@ def query(self, request): raise NotImplementedError(f"request of type {type(request)} not supported.") @query.register - def _(self, request: mars.Request): + def _(self, request: dict): # The presence of the yield keyword makes this def a generator. # As a result, the context manager will remain active until the # exhaustion of the data source iterator. - grib_def = config.get("data_scope", GRIB_DEF[request.model]) + req_kwargs = self.request_template | request + req = mars.Request(**req_kwargs) + + grib_def = config.get("data_scope", GRIB_DEF[req.model]) with grib_def_ctx(grib_def): if self.datafiles: fs = ekd.from_source("file", self.datafiles) - source = fs.sel(request.dump(exclude_defaults=True)) + source = fs.sel(req_kwargs) else: - source = ekd.from_source("fdb", request.to_fdb()) + source = ekd.from_source("fdb", req.to_fdb()) yield from source - @query.register - def _(self, request: dict): - req_kwargs = self.request_template | request - req = mars.Request(**req_kwargs) - yield from self.query(req) - @query.register def _(self, request: str): yield from self.query({"param": request}) diff --git a/src/idpi/mars.py b/src/idpi/mars.py index a0b2e7f8..775ab844 100644 --- a/src/idpi/mars.py +++ b/src/idpi/mars.py @@ -12,6 +12,9 @@ from pydantic import dataclasses as pdc +ValidationError = pydantic.ValidationError + + class Class(str, Enum): OPERATIONAL_DATA = "od" @@ -71,7 +74,7 @@ class Request: expver: str = "0001" levelist: int | tuple[int, ...] | None = None - number: int | tuple[int, ...] = 1 + number: int | tuple[int, ...] = 0 step: int | tuple[int, ...] = 0 class_: Class = dc.field( @@ -83,13 +86,12 @@ class Request: stream: Stream = Stream.ENS_FORECAST type: Type = Type.ENS_MEMBER - def dump(self, exclude_defaults: bool = False): + def dump(self): root = pydantic.RootModel(self) return root.model_dump( mode="json", by_alias=True, exclude_none=True, - exclude_defaults=exclude_defaults, ) def to_fdb(self): @@ -97,6 +99,9 @@ def to_fdb(self): param_id = mapping[self.param]["cosmo"]["paramId"] staggered = mapping[self.param]["cosmo"].get("vertStag", False) + if self.date is None or self.time is None: + raise RuntimeError("date and time are required fields for FDB.") + if self.levelist is None and self.levtype == LevType.MODEL_LEVEL: n_lvl = N_LVL[self.model] if staggered: @@ -106,4 +111,4 @@ def to_fdb(self): levelist = self.levelist obj = dc.replace(self, param=param_id, levelist=levelist) - return obj.dump(exclude_defaults=False) + return obj.dump() diff --git a/tests/test_idpi/test_data_source.py b/tests/test_idpi/test_data_source.py index 17966c52..8b5b4452 100644 --- a/tests/test_idpi/test_data_source.py +++ b/tests/test_idpi/test_data_source.py @@ -1,10 +1,25 @@ +from contextlib import nullcontext from unittest.mock import patch, call -from idpi import data_source, mars +import pytest +from idpi import config, data_source, mars -@patch.object(data_source.ekd, "from_source") -def test_query_files(mock_from_source): + +@pytest.fixture +def mock_from_source(): + with patch.object(data_source.ekd, "from_source") as mock: + yield mock + + +@pytest.fixture +def mock_grib_def_ctx(): + with patch.object(data_source, "grib_def_ctx") as mock: + mock.return_value = nullcontext() + yield mock + + +def test_query_files(mock_from_source, mock_grib_def_ctx): datafiles = ["foo"] param = "bar" @@ -12,6 +27,40 @@ def test_query_files(mock_from_source): for _ in ds.query(param): pass + assert mock_grib_def_ctx.mock_calls == [call("cosmo")] + assert mock_from_source.mock_calls == [ + call("file", datafiles), + call().sel({"param": param}), + call().sel().__iter__(), + ] + + +def test_query_files_tuple(mock_from_source, mock_grib_def_ctx): + datafiles = ["foo"] + request = param, levtype = ("bar", "ml") + + ds = data_source.DataSource(datafiles) + for _ in ds.query(request): + pass + + assert mock_grib_def_ctx.mock_calls == [call("cosmo")] + assert mock_from_source.mock_calls == [ + call("file", datafiles), + call().sel({"param": param, "levtype": levtype}), + call().sel().__iter__(), + ] + + +def test_query_files_ifs(mock_from_source, mock_grib_def_ctx): + datafiles = ["foo"] + param = "bar" + + with config.set_values(data_scope="ifs"): + ds = data_source.DataSource(datafiles) + for _ in ds.query(param): + pass + + assert mock_grib_def_ctx.mock_calls == [call("ifs")] assert mock_from_source.mock_calls == [ call("file", datafiles), call().sel({"param": param}), @@ -19,8 +68,7 @@ def test_query_files(mock_from_source): ] -@patch.object(data_source.ekd, "from_source") -def test_query_fdb(mock_from_source): +def test_query_fdb(mock_from_source, mock_grib_def_ctx): datafiles = [] param = "U" template = {"date": "20200101", "time": "0000"} @@ -29,6 +77,7 @@ def test_query_fdb(mock_from_source): for _ in ds.query(param): pass + assert mock_grib_def_ctx.mock_calls == [call("cosmo")] assert mock_from_source.mock_calls == [ call("fdb", mars.Request(param, **template).to_fdb()), call().__iter__(), diff --git a/tests/test_idpi/test_mars.py b/tests/test_idpi/test_mars.py index d38481f3..7e51d50d 100644 --- a/tests/test_idpi/test_mars.py +++ b/tests/test_idpi/test_mars.py @@ -15,7 +15,7 @@ def sample(): "levtype": "ml", "levelist": list(range(1, 81)), "model": "COSMO-1E", - "number": 1, + "number": 0, "stream": "enfo", "param": 500028, # U "time": "0000", @@ -59,10 +59,3 @@ def test_fdb_sfc(sample): def test_request_raises(): with pytest.raises(ValueError): mars.Request("U", date="20200101", time="0000", model="undef") - - -def test_no_defaults(): - observed = mars.Request("U").dump(exclude_defaults=True) - expected = {"param": "U"} - - assert observed == expected