From 66715be4b90cb7270d85500d7f7f0185b1f0e8eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 10 Jul 2024 09:49:44 +0200 Subject: [PATCH 1/4] Modify timeseries_property_data fixture: alternate min/max --- tests/conftest.py | 35 +++++++++---------- .../test_timeseries_property_data.py | 6 ++-- 2 files changed, 20 insertions(+), 21 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 42dd8de..04120fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -230,14 +230,13 @@ def timeseries(request, app, campaigns, campaign_scopes): with OpenBar(): ts_l = [] for i in range(request.param): - ts_i = model.Timeseries( + ts_i = model.Timeseries.new( name=f"Timeseries {i}", description=f"Test timeseries #{i}", campaign_id=campaigns[i % len(campaigns)], campaign_scope_id=campaign_scopes[i % len(campaign_scopes)], ) ts_l.append(ts_i) - db.session.add_all(ts_l) db.session.commit() return [ts.id for ts in ts_l] @@ -246,22 +245,23 @@ def timeseries(request, app, campaigns, campaign_scopes): def timeseries_property_data(request, app, timeseries_properties, timeseries): with OpenBar(): tspd_l = [] - for ts in timeseries: - tspd_l.append( - model.TimeseriesPropertyData( - timeseries_id=ts, - property_id=timeseries_properties[0], - value=12, + for idx, ts in enumerate(timeseries): + if idx % 2 == 0: + tspd_l.append( + model.TimeseriesPropertyData.new( + timeseries_id=ts, + property_id=timeseries_properties[0], + value=12, + ) ) - ) - tspd_l.append( - model.TimeseriesPropertyData( - timeseries_id=ts, - property_id=timeseries_properties[1], - value=42, + else: + tspd_l.append( + model.TimeseriesPropertyData.new( + timeseries_id=ts, + property_id=timeseries_properties[1], + value=42, + ) ) - ) - db.session.add_all(tspd_l) db.session.commit() return [tspd.id for tspd in tspd_l] @@ -271,12 +271,11 @@ def timeseries_by_data_states(request, app, timeseries): with OpenBar(): ts_l = [] for i in range(request.param): - ts_i = model.TimeseriesByDataState( + ts_i = model.TimeseriesByDataState.new( timeseries_id=timeseries[i % len(timeseries)], data_state_id=1, ) ts_l.append(ts_i) - db.session.add_all(ts_l) db.session.commit() return [ts.id for ts in ts_l] diff --git a/tests/resources/test_timeseries_property_data.py b/tests/resources/test_timeseries_property_data.py index b4e72f8..08332f4 100644 --- a/tests/resources/test_timeseries_property_data.py +++ b/tests/resources/test_timeseries_property_data.py @@ -154,7 +154,7 @@ def test_timeseries_property_data_as_user_api( tsp_1_id = timeseries_properties[0] ts_1_id = timeseries[0] tspd_1_id = timeseries_property_data[0] - tspd_3_id = timeseries_property_data[2] + tspd_2_id = timeseries_property_data[1] creds = users["Active"]["creds"] @@ -165,7 +165,7 @@ def test_timeseries_property_data_as_user_api( ret = client.get(TIMESERIES_PROPERTY_DATA_URL) assert ret.status_code == 200 ret_val = ret.json - assert len(ret_val) == 2 + assert len(ret_val) == 1 assert all([tspd["timeseries_id"] == ts_1_id for tspd in ret_val]) # POST @@ -188,7 +188,7 @@ def test_timeseries_property_data_as_user_api( tspd_1 = ret_val # GET by id, user not in group - ret = client.get(f"{TIMESERIES_PROPERTY_DATA_URL}{tspd_3_id}") + ret = client.get(f"{TIMESERIES_PROPERTY_DATA_URL}{tspd_2_id}") assert ret.status_code == 403 # PUT From 388d51046dfa75e79b2ed2c2d2cdf1ea221b2174 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 10 Jul 2024 15:00:02 +0200 Subject: [PATCH 2/4] Add DictStr field --- src/bemserver_api/extensions/ma_fields.py | 22 +++++++++++++++ src/bemserver_api/extensions/smorest.py | 34 +++++++++++++++++++++-- tests/extensions/test_ma_fields.py | 12 +++++++- 3 files changed, 65 insertions(+), 3 deletions(-) diff --git a/src/bemserver_api/extensions/ma_fields.py b/src/bemserver_api/extensions/ma_fields.py index b765b5c..61bd1aa 100644 --- a/src/bemserver_api/extensions/ma_fields.py +++ b/src/bemserver_api/extensions/ma_fields.py @@ -80,3 +80,25 @@ def __init__(self, fields, **kwargs): [v for f in fields for v in [f, f"+{f}", f"-{f}"]] ) super().__init__(ma.fields.String(validate=validator), **kwargs) + + +class DictStr(ma.fields.Dict): + default_error_messages = { + "invalid": "Not a valid string.", + "invalid_utf8": "Not a valid utf-8 string.", + "invalid_json": "Not a valid json object.", + } + + def _deserialize(self, value, attr, data, **kwargs): + if not isinstance(value, (str, bytes)): + raise self.make_error("invalid") + try: + if isinstance(value, bytes): + value = value.decode("utf-8") + except UnicodeDecodeError as exc: + raise self.make_error("invalid_utf8") from exc + try: + value = json.loads(value) + except json.decoder.JSONDecodeError as exc: + raise self.make_error("invalid_json") from exc + return super()._deserialize(value, attr, data, **kwargs) diff --git a/src/bemserver_api/extensions/smorest.py b/src/bemserver_api/extensions/smorest.py index 7707a64..f1e2528 100644 --- a/src/bemserver_api/extensions/smorest.py +++ b/src/bemserver_api/extensions/smorest.py @@ -10,14 +10,15 @@ import flask_smorest import marshmallow as ma import marshmallow_sqlalchemy as msa -from apispec.ext.marshmallow import MarshmallowPlugin +from apispec.ext.marshmallow import MarshmallowPlugin as MarshmallowPluginOrig +from apispec.ext.marshmallow import OpenAPIConverter as OrigOpenAPIConverter from apispec.ext.marshmallow.common import resolve_schema_cls from bemserver_core.authorization import get_current_user from . import integrity_error from .authentication import auth -from .ma_fields import Timezone +from .ma_fields import DictStr, Timezone def resolver(schema): @@ -44,6 +45,35 @@ def resolver(schema): } +class OpenAPIConverter(OrigOpenAPIConverter): + def _field2parameter(self, field, *, name, location): + ret: dict = {"in": location, "name": name} + + prop = self.field2property(field) + if self.openapi_version.major < 3: + ret.update(prop) + else: + if "description" in prop: + ret["description"] = prop.pop("description") + if "deprecated" in prop: + ret["deprecated"] = prop.pop("deprecated") + ret["schema"] = prop + + # Document DictStr as "content" parameter + # https://github.com/marshmallow-code/apispec/issues/922 + if isinstance(field, DictStr): + ret["content"] = {"application/json": ret.pop("schema")} + + for param_attr_func in self.parameter_attribute_functions: + ret.update(param_attr_func(field, ret=ret)) + + return ret + + +class MarshmallowPlugin(MarshmallowPluginOrig): + Converter = OpenAPIConverter + + class Api(flask_smorest.Api): """Api class""" diff --git a/tests/extensions/test_ma_fields.py b/tests/extensions/test_ma_fields.py index 1b6f814..63085bd 100644 --- a/tests/extensions/test_ma_fields.py +++ b/tests/extensions/test_ma_fields.py @@ -2,7 +2,7 @@ import marshmallow as ma -from bemserver_api.extensions.ma_fields import Timezone, UnitSymbol +from bemserver_api.extensions.ma_fields import DictStr, Timezone, UnitSymbol class TestMaFields: @@ -25,3 +25,13 @@ def test_ma_fields_unit_symbol(self): field.deserialize("wh") with pytest.raises(ma.ValidationError): field.deserialize("dummy") + + def test_ma_fields_dictstr(self): + field = DictStr() + assert field.deserialize('{"lol": "rofl"}') == {"lol": "rofl"} + with pytest.raises(ma.ValidationError, match="Not a valid string."): + field.deserialize(12) + with pytest.raises(ma.ValidationError, match="Not a valid utf-8 string."): + field.deserialize(b"\xf3") + with pytest.raises(ma.ValidationError, match="Not a valid json object."): + field.deserialize("{'lol': 'rofl'}") From 366069ca3c23896cd00debffecee03f7b06aaf14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 10 Jul 2024 16:18:07 +0200 Subject: [PATCH 3/4] Require bemserver-core 0.18.1 --- pyproject.toml | 2 +- requirements/install.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e5d4c0a..3666d2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,7 +33,7 @@ dependencies = [ "flask_smorest>=0.43.0,<0.44", "apispec>=6.1.0,<7.0", "authlib>=1.3.0,<2.0", - "bemserver-core>=0.18.0,<0.19", + "bemserver-core>=0.18.1,<0.19", ] [project.urls] diff --git a/requirements/install.txt b/requirements/install.txt index 7986bab..b9077cf 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -20,7 +20,7 @@ async-timeout==4.0.3 # via redis authlib==1.3.1 # via bemserver-api (pyproject.toml) -bemserver-core==0.18.0 +bemserver-core==0.18.1 # via bemserver-api (pyproject.toml) billiard==4.2.0 # via celery From 5e7e067f33534b652259fb671b89048f133ba311 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 10 Jul 2024 16:16:12 +0200 Subject: [PATCH 4/4] Allow filtering timeseries by property data --- .../resources/timeseries/schemas.py | 3 ++ tests/resources/test_timeseries.py | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/src/bemserver_api/resources/timeseries/schemas.py b/src/bemserver_api/resources/timeseries/schemas.py index 4e0cf7c..e0f62ca 100644 --- a/src/bemserver_api/resources/timeseries/schemas.py +++ b/src/bemserver_api/resources/timeseries/schemas.py @@ -40,6 +40,9 @@ class TimeseriesQueryArgsSchema(Schema): space_id = ma.fields.Int() zone_id = ma.fields.Int() event_id = ma.fields.Int() + properties = ma_fields.DictStr( + ma.fields.String(), ma.fields.Str(), metadata={"example": {"Min": "0"}} + ) @ma.validates_schema def validate_conflicting_fields(self, data, **kwargs): diff --git a/tests/resources/test_timeseries.py b/tests/resources/test_timeseries.py index d6cd083..179a83b 100644 --- a/tests/resources/test_timeseries.py +++ b/tests/resources/test_timeseries.py @@ -162,6 +162,42 @@ def test_timeseries_api(self, app, users, campaigns, campaign_scopes): ret = client.get(f"{TIMESERIES_URL}{timeseries_1_id}") assert ret.status_code == 404 + @pytest.mark.usefixtures("timeseries_properties") + @pytest.mark.usefixtures("timeseries_property_data") + def test_timeseries_filter_by_properties_data_api(self, app, users): + creds = users["Chuck"]["creds"] + + client = app.test_client() + + with AuthHeader(creds): + ret = client.get(TIMESERIES_URL) + ret_val = ret.json + assert len(ret_val) == 2 + + ret = client.get( + TIMESERIES_URL, query_string={"properties": '{"Min": "12"}'} + ) + assert ret.status_code == 200 + ret_val = ret.json + assert len(ret_val) == 1 + + # Invalid property name + ret = client.get( + TIMESERIES_URL, query_string={"properties": '{"Dummy": "12"}'} + ) + assert ret.status_code == 200 + ret_val = ret.json + assert not ret_val + + # Not dicts of strings + ret = client.get( + TIMESERIES_URL, query_string={"properties": '{12: "Dummy"}'} + ) + assert ret.status_code == 422 + ret_val = ret.json + ret = client.get(TIMESERIES_URL, query_string={"properties": '{"Min": 12}'}) + assert ret.status_code == 422 + @pytest.mark.usefixtures("timeseries_by_spaces") @pytest.mark.usefixtures("timeseries_by_zones") @pytest.mark.usefixtures("timeseries_by_events")