From 40c2944344e0ba78d9211e69b03a75545ffc3e96 Mon Sep 17 00:00:00 2001 From: B Krishna Chaitanya Date: Wed, 2 Sep 2020 10:39:02 +0530 Subject: [PATCH] Add support for RangeFloat type --- mobilitydb_sqlalchemy/__init__.py | 1 + mobilitydb_sqlalchemy/types/RangeFloat.py | 11 ++++++++ tests/models.py | 7 +++++ tests/test_rangefloat.py | 33 +++++++++++++++++++++++ 4 files changed, 52 insertions(+) create mode 100644 mobilitydb_sqlalchemy/types/RangeFloat.py create mode 100644 tests/test_rangefloat.py diff --git a/mobilitydb_sqlalchemy/__init__.py b/mobilitydb_sqlalchemy/__init__.py index c44275c..c62b767 100644 --- a/mobilitydb_sqlalchemy/__init__.py +++ b/mobilitydb_sqlalchemy/__init__.py @@ -3,6 +3,7 @@ from .types.STBox import STBox # Range Types +from .types.RangeFloat import RangeFloat from .types.RangeInt import RangeInt # Time Types diff --git a/mobilitydb_sqlalchemy/types/RangeFloat.py b/mobilitydb_sqlalchemy/types/RangeFloat.py new file mode 100644 index 0000000..194c5f4 --- /dev/null +++ b/mobilitydb_sqlalchemy/types/RangeFloat.py @@ -0,0 +1,11 @@ +from pymeos.range import RangeFloat as MEOSRangeFloat +from sqlalchemy.types import UserDefinedType + +from .BaseType import BaseType + + +class RangeFloat(BaseType): + base_class = MEOSRangeFloat + + def get_col_spec(self): + return "FLOATRANGE" diff --git a/tests/models.py b/tests/models.py index df9863a..d66e1a7 100644 --- a/tests/models.py +++ b/tests/models.py @@ -10,6 +10,7 @@ from mobilitydb_sqlalchemy import ( Period, PeriodSet, + RangeFloat, RangeInt, STBox, TBool, @@ -36,6 +37,12 @@ class STBoxes(Base): stbox = Column(STBox) +class RangeFloats(Base): + __tablename__ = "rangefloat_test_001" + id = Column(Integer, primary_key=True) + rangefloat = Column(RangeFloat) + + class RangeInts(Base): __tablename__ = "rangeint_test_001" id = Column(Integer, primary_key=True) diff --git a/tests/test_rangefloat.py b/tests/test_rangefloat.py new file mode 100644 index 0000000..3b45995 --- /dev/null +++ b/tests/test_rangefloat.py @@ -0,0 +1,33 @@ +import datetime +import pytest +from pymeos.range import RangeFloat +from sqlalchemy.exc import StatementError + +from .models import RangeFloats + + +def test_simple_insert(session): + rangefloat = RangeFloat(10, 20) + + session.add(RangeFloats(rangefloat=rangefloat)) + session.commit() + + sql = session.query(RangeFloats).filter(RangeFloats.id == 1) + assert sql.count() == 1 + + results = sql.all() + for result in results: + assert result.id == 1 + assert str(result.rangefloat) == "[10, 20)" + assert result.rangefloat.lower == 10 + assert result.rangefloat.upper == 20 + assert result.rangefloat.lower_inc == True + assert result.rangefloat.upper_inc == False + + +def test_str_values_are_invalid(session): + rangefloat = "[10, 20)" + + with pytest.raises(StatementError): + session.add(RangeFloats(rangefloat=rangefloat)) + session.commit()