Skip to content
This repository has been archived by the owner on Mar 24, 2024. It is now read-only.

Commit

Permalink
Merge pull request #286 from quantmind/ls-db
Browse files Browse the repository at this point in the history
Better handling of defaults
  • Loading branch information
Luca Sbardella authored Apr 20, 2022
2 parents 67730d7 + 6dde6f0 commit a3f06a2
Show file tree
Hide file tree
Showing 15 changed files with 183 additions and 91 deletions.
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"RedirectOutput"
],
"args": [
"tests/spec/test_schema_parser.py::test_schema2json"
"tests/core/test_dc_db.py"
]
}
]
Expand Down
13 changes: 0 additions & 13 deletions mypy.ini

This file was deleted.

2 changes: 1 addition & 1 deletion openapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""Minimal OpenAPI asynchronous server application"""
__version__ = "2.7.2"
__version__ = "2.8.0"
123 changes: 88 additions & 35 deletions openapi/data/db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import typing as t
from dataclasses import make_dataclass
from datetime import date, datetime
from decimal import Decimal
from functools import partial
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union, cast

import sqlalchemy as sa
from sqlalchemy_utils import UUIDType
Expand All @@ -15,35 +16,54 @@ def dataclass_from_table(
name: str,
table: sa.Table,
*,
exclude: t.Optional[t.Sequence[str]] = None,
include: t.Optional[t.Sequence[str]] = None,
required: bool = False,
ops: t.Optional[t.Dict[str, t.Sequence[str]]] = None,
exclude: Optional[Sequence[str]] = None,
include: Optional[Sequence[str]] = None,
default: Union[bool, Sequence[str]] = False,
required: Union[bool, Sequence[str]] = False,
ops: Optional[Dict[str, Sequence[str]]] = None,
) -> type:
"""Create a dataclass from an :class:`sqlalchemy.schema.Table`
:param name: dataclass name
:param table: sqlalchemy table
:param exclude: fields to exclude from the dataclass
:param include: fields to include in the dataclass
:param required: set all non nullable columns as required fields in the dataclass
:param default: use columns defaults in the dataclass
:param required: set non nullable columns without a default as
required fields in the dataclass
:param ops: additional operation for fields
"""
columns = []
include = set(include or table.columns.keys()) - set(exclude or ())
column_ops = t.cast(t.Dict[str, t.Sequence[str]], ops or {})
defaults = column_info(include, default)
requireds = column_info(include, required)
column_ops = cast(Dict[str, Sequence[str]], ops or {})
for col in table.columns:
if col.name not in include:
continue
ctype = type(col.type)
converter = CONVERTERS.get(ctype)
if not converter: # pragma: no cover
raise NotImplementedError(f"Cannot convert column {col.name}: {ctype}")
field = (col.name, *converter(col, required, column_ops.get(col.name, ())))
required = col.name in requireds
use_default = col.name in defaults
field = (
col.name,
*converter(col, required, use_default, column_ops.get(col.name, ())),
)
columns.append(field)
return make_dataclass(name, columns)


def column_info(columns: Set[str], value: Union[bool, Sequence[str]]) -> Set[str]:
if value is False:
return set()
elif value is True:
return columns.copy()
else:
return set(value if value is not None else columns)


def converter(*types):
def _(f):
for type_ in types:
Expand All @@ -54,81 +74,114 @@ def _(f):


@converter(sa.Boolean)
def bl(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
def bl(col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]) -> Tuple:
data_field = col.info.get("data_field", fields.bool_field)
return (bool, data_field(**info(col, required, ops)))
return (bool, data_field(**info(col, required, use_default, ops)))


@converter(sa.Integer)
def integer(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
def integer(
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple:
data_field = col.info.get("data_field", fields.number_field)
return (int, data_field(precision=0, **info(col, required, ops)))
return (int, data_field(precision=0, **info(col, required, use_default, ops)))


@converter(sa.Numeric)
def number(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
def number(
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple:
data_field = col.info.get("data_field", fields.decimal_field)
return (Decimal, data_field(precision=col.type.scale, **info(col, required, ops)))
return (
Decimal,
data_field(precision=col.type.scale, **info(col, required, use_default, ops)),
)


@converter(sa.String, sa.Text, sa.CHAR, sa.VARCHAR)
def string(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
def string(
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple:
data_field = col.info.get("data_field", fields.str_field)
return (
str,
data_field(max_length=col.type.length or 0, **info(col, required, ops)),
data_field(
max_length=col.type.length or 0, **info(col, required, use_default, ops)
),
)


@converter(sa.DateTime)
def dt_ti(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
def dt_ti(
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple:
data_field = col.info.get("data_field", fields.date_time_field)
return (
datetime,
data_field(timezone=col.type.timezone, **info(col, required, ops)),
data_field(timezone=col.type.timezone, **info(col, required, use_default, ops)),
)


@converter(sa.Date)
def dt(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
def dt(col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]) -> Tuple:
data_field = col.info.get("data_field", fields.date_field)
return (date, data_field(**info(col, required, ops)))
return (date, data_field(**info(col, required, use_default, ops)))


@converter(sa.Enum)
def en(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
def en(col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]) -> Tuple:
data_field = col.info.get("data_field", fields.enum_field)
return (
col.type.enum_class,
data_field(col.type.enum_class, **info(col, required, ops)),
data_field(col.type.enum_class, **info(col, required, use_default, ops)),
)


@converter(sa.JSON)
def js(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
def js(col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]) -> Tuple:
data_field = col.info.get("data_field", fields.json_field)
val = None
if col.default:
arg = col.default.arg
val = arg() if col.default.is_callable else arg
return (JsonTypes.get(type(val), t.Dict), data_field(**info(col, required, ops)))
return (
JsonTypes.get(type(val), Dict),
data_field(**info(col, required, use_default, ops)),
)


@converter(UUIDType)
def uuid(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
def uuid(
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple:
data_field = col.info.get("data_field", fields.uuid_field)
return (str, data_field(**info(col, required, ops)))


def info(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
data = dict(
description=col.doc,
required=not col.nullable if required is not False else False,
ops=ops,
)
return (str, data_field(**info(col, required, use_default, ops)))


def info(
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
) -> Tuple:
data = dict(ops=ops)
if use_default:
default = col.default.arg if col.default is not None else None
if callable(default):
data.update(default_factory=partial(default, None))
required = False
elif isinstance(default, (list, dict, set)):
data.update(default_factory=lambda: default.copy())
required = False
else:
data.update(default=default)
if required and (col.nullable or default is not None):
required = False
elif required and col.nullable:
required = False
data.update(required=required)
if col.doc:
data.update(description=col.doc)
data.update(col.info)
data.pop("data_field", None)
return data


JsonTypes = {list: t.List, dict: t.Dict}
JsonTypes = {list: List, dict: Dict}
1 change: 1 addition & 0 deletions openapi/data/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def data_field(
"""
if isinstance(validator, Validator) and not dump:
dump = validator.dump
# Add default None otherwisenon-default fields can follow default ones
if "default_factory" not in kwargs:
kwargs.setdefault("default", None)
meta = meta or {}
Expand Down
3 changes: 0 additions & 3 deletions openapi/db/dbmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,6 @@ def default_filter_field(self, field: Column, op: str, value: Any):
"""
multiple = isinstance(value, (list, tuple))

if value == "":
value = None

if multiple and op in ("eq", "ne"):
if op == "eq":
return field.in_(value)
Expand Down
2 changes: 1 addition & 1 deletion openapi/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def with_test_db(db: CrudDB) -> CrudDB:
db.drop_all_schemas()


class SingleConnDatabase(CrudDB):
class SingleConnDatabase(CrudDB): # noqa
"""Useful for speedup testing"""

def __init__(self, *args, **kwargs) -> None:
Expand Down
38 changes: 37 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aio-openapi"
version = "2.7.2"
version = "2.8.0"
description = "Minimal OpenAPI asynchronous server application"
documentation = "https://aio-openapi.readthedocs.io"
repository = "https://github.com/quantmind/aio-openapi"
Expand Down Expand Up @@ -83,6 +83,9 @@ pytest-cov = "^3.0.0"
pytest-aiohttp = "^0.3.0"
pytest-mock = "^3.6.1"
isort = "^5.10.1"
types-pytz = "^2021.3.6"
types-simplejson = "^3.17.5"
types-python-dateutil = "^2.8.11"

[tool.poetry.extras]
dev = ["aiodns", "PyJWT", "colorlog", "phonenumbers", "cchardet"]
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ multi_line_output=3
include_trailing_comma=True

[mypy]
python_version = 3.7
plugins = sqlmypy
ignore_missing_imports=True
disallow_untyped_calls=False
Expand Down
Loading

0 comments on commit a3f06a2

Please sign in to comment.