Skip to content

Commit

Permalink
Add support for RangeFloat type
Browse files Browse the repository at this point in the history
  • Loading branch information
chaitan94 committed Sep 2, 2020
1 parent a43205c commit 40c2944
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 0 deletions.
1 change: 1 addition & 0 deletions mobilitydb_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .types.STBox import STBox

# Range Types
from .types.RangeFloat import RangeFloat
from .types.RangeInt import RangeInt

# Time Types
Expand Down
11 changes: 11 additions & 0 deletions mobilitydb_sqlalchemy/types/RangeFloat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from pymeos.range import RangeFloat as MEOSRangeFloat
from sqlalchemy.types import UserDefinedType

from .BaseType import BaseType


class RangeFloat(BaseType):
base_class = MEOSRangeFloat

def get_col_spec(self):
return "FLOATRANGE"
7 changes: 7 additions & 0 deletions tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from mobilitydb_sqlalchemy import (
Period,
PeriodSet,
RangeFloat,
RangeInt,
STBox,
TBool,
Expand All @@ -36,6 +37,12 @@ class STBoxes(Base):
stbox = Column(STBox)


class RangeFloats(Base):
__tablename__ = "rangefloat_test_001"
id = Column(Integer, primary_key=True)
rangefloat = Column(RangeFloat)


class RangeInts(Base):
__tablename__ = "rangeint_test_001"
id = Column(Integer, primary_key=True)
Expand Down
33 changes: 33 additions & 0 deletions tests/test_rangefloat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import datetime
import pytest
from pymeos.range import RangeFloat
from sqlalchemy.exc import StatementError

from .models import RangeFloats


def test_simple_insert(session):
rangefloat = RangeFloat(10, 20)

session.add(RangeFloats(rangefloat=rangefloat))
session.commit()

sql = session.query(RangeFloats).filter(RangeFloats.id == 1)
assert sql.count() == 1

results = sql.all()
for result in results:
assert result.id == 1
assert str(result.rangefloat) == "[10, 20)"
assert result.rangefloat.lower == 10
assert result.rangefloat.upper == 20
assert result.rangefloat.lower_inc == True
assert result.rangefloat.upper_inc == False


def test_str_values_are_invalid(session):
rangefloat = "[10, 20)"

with pytest.raises(StatementError):
session.add(RangeFloats(rangefloat=rangefloat))
session.commit()

0 comments on commit 40c2944

Please sign in to comment.