Skip to content
This repository was archived by the owner on Mar 24, 2024. It is now read-only.

Commit c67a6eb

Browse files
author
Luca
committed
Add default to the dataclass builder
1 parent 27e7ff1 commit c67a6eb

File tree

4 files changed

+85
-41
lines changed

4 files changed

+85
-41
lines changed

openapi/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
"""Minimal OpenAPI asynchronous server application"""
2-
__version__ = "2.7.2"
2+
__version__ = "2.8.0"

openapi/data/db.py

Lines changed: 80 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import typing as t
21
from dataclasses import make_dataclass
32
from datetime import date, datetime
43
from decimal import Decimal
54
from functools import partial
5+
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union, cast
66

77
import sqlalchemy as sa
88
from sqlalchemy_utils import UUIDType
@@ -16,35 +16,54 @@ def dataclass_from_table(
1616
name: str,
1717
table: sa.Table,
1818
*,
19-
exclude: t.Optional[t.Sequence[str]] = None,
20-
include: t.Optional[t.Sequence[str]] = None,
21-
required: bool = False,
22-
ops: t.Optional[t.Dict[str, t.Sequence[str]]] = None,
19+
exclude: Optional[Sequence[str]] = None,
20+
include: Optional[Sequence[str]] = None,
21+
default: Union[bool, Sequence[str]] = False,
22+
required: Union[bool, Sequence[str]] = False,
23+
ops: Optional[Dict[str, Sequence[str]]] = None,
2324
) -> type:
2425
"""Create a dataclass from an :class:`sqlalchemy.schema.Table`
2526
2627
:param name: dataclass name
2728
:param table: sqlalchemy table
2829
:param exclude: fields to exclude from the dataclass
2930
:param include: fields to include in the dataclass
30-
:param required: set all non nullable columns as required fields in the dataclass
31+
:param default: use columns defaults in the dataclass
32+
:param required: set non nullable columns without a default as
33+
required fields in the dataclass
3134
:param ops: additional operation for fields
3235
"""
3336
columns = []
3437
include = set(include or table.columns.keys()) - set(exclude or ())
35-
column_ops = t.cast(t.Dict[str, t.Sequence[str]], ops or {})
38+
defaults = column_info(include, default)
39+
requireds = column_info(include, required)
40+
column_ops = cast(Dict[str, Sequence[str]], ops or {})
3641
for col in table.columns:
3742
if col.name not in include:
3843
continue
3944
ctype = type(col.type)
4045
converter = CONVERTERS.get(ctype)
4146
if not converter: # pragma: no cover
4247
raise NotImplementedError(f"Cannot convert column {col.name}: {ctype}")
43-
field = (col.name, *converter(col, required, column_ops.get(col.name, ())))
48+
required = col.name in requireds
49+
use_default = col.name in defaults
50+
field = (
51+
col.name,
52+
*converter(col, required, use_default, column_ops.get(col.name, ())),
53+
)
4454
columns.append(field)
4555
return make_dataclass(name, columns)
4656

4757

58+
def column_info(columns: Set[str], value: Union[bool, Sequence[str]]) -> Set[str]:
59+
if value is False:
60+
return set()
61+
elif value is True:
62+
return columns.copy()
63+
else:
64+
return set(value if value is not None else columns)
65+
66+
4867
def converter(*types):
4968
def _(f):
5069
for type_ in types:
@@ -55,85 +74,108 @@ def _(f):
5574

5675

5776
@converter(sa.Boolean)
58-
def bl(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
77+
def bl(col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]) -> Tuple:
5978
data_field = col.info.get("data_field", fields.bool_field)
60-
return (bool, data_field(**info(col, required, ops)))
79+
return (bool, data_field(**info(col, required, use_default, ops)))
6180

6281

6382
@converter(sa.Integer)
64-
def integer(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
83+
def integer(
84+
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
85+
) -> Tuple:
6586
data_field = col.info.get("data_field", fields.number_field)
66-
return (int, data_field(precision=0, **info(col, required, ops)))
87+
return (int, data_field(precision=0, **info(col, required, use_default, ops)))
6788

6889

6990
@converter(sa.Numeric)
70-
def number(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
91+
def number(
92+
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
93+
) -> Tuple:
7194
data_field = col.info.get("data_field", fields.decimal_field)
72-
return (Decimal, data_field(precision=col.type.scale, **info(col, required, ops)))
95+
return (
96+
Decimal,
97+
data_field(precision=col.type.scale, **info(col, required, use_default, ops)),
98+
)
7399

74100

75101
@converter(sa.String, sa.Text, sa.CHAR, sa.VARCHAR)
76-
def string(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
102+
def string(
103+
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
104+
) -> Tuple:
77105
data_field = col.info.get("data_field", fields.str_field)
78106
return (
79107
str,
80-
data_field(max_length=col.type.length or 0, **info(col, required, ops)),
108+
data_field(
109+
max_length=col.type.length or 0, **info(col, required, use_default, ops)
110+
),
81111
)
82112

83113

84114
@converter(sa.DateTime)
85-
def dt_ti(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
115+
def dt_ti(
116+
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
117+
) -> Tuple:
86118
data_field = col.info.get("data_field", fields.date_time_field)
87119
return (
88120
datetime,
89-
data_field(timezone=col.type.timezone, **info(col, required, ops)),
121+
data_field(timezone=col.type.timezone, **info(col, required, use_default, ops)),
90122
)
91123

92124

93125
@converter(sa.Date)
94-
def dt(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
126+
def dt(col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]) -> Tuple:
95127
data_field = col.info.get("data_field", fields.date_field)
96-
return (date, data_field(**info(col, required, ops)))
128+
return (date, data_field(**info(col, required, use_default, ops)))
97129

98130

99131
@converter(sa.Enum)
100-
def en(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
132+
def en(col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]) -> Tuple:
101133
data_field = col.info.get("data_field", fields.enum_field)
102134
return (
103135
col.type.enum_class,
104-
data_field(col.type.enum_class, **info(col, required, ops)),
136+
data_field(col.type.enum_class, **info(col, required, use_default, ops)),
105137
)
106138

107139

108140
@converter(sa.JSON)
109-
def js(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
141+
def js(col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]) -> Tuple:
110142
data_field = col.info.get("data_field", fields.json_field)
111143
val = None
112144
if col.default:
113145
arg = col.default.arg
114146
val = arg() if col.default.is_callable else arg
115-
return (JsonTypes.get(type(val), t.Dict), data_field(**info(col, required, ops)))
147+
return (
148+
JsonTypes.get(type(val), Dict),
149+
data_field(**info(col, required, use_default, ops)),
150+
)
116151

117152

118153
@converter(UUIDType)
119-
def uuid(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
154+
def uuid(
155+
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
156+
) -> Tuple:
120157
data_field = col.info.get("data_field", fields.uuid_field)
121-
return (str, data_field(**info(col, required, ops)))
158+
return (str, data_field(**info(col, required, use_default, ops)))
122159

123160

124-
def info(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
161+
def info(
162+
col: sa.Column, required: bool, use_default: bool, ops: Sequence[str]
163+
) -> Tuple:
125164
data = dict(ops=ops)
126-
default = col.default.arg if col.default is not None else None
127-
if callable(default):
128-
data.update(default_factory=partial(default, None))
129-
required = False
130-
elif isinstance(default, (list, dict, set)):
131-
data.update(default_factory=lambda: default.copy())
132-
required = False
133-
else:
134-
data.update(default=default)
135-
if required and (col.nullable or default is not None):
165+
if use_default:
166+
default = col.default.arg if col.default is not None else None
167+
if callable(default):
168+
data.update(default_factory=partial(default, None))
169+
required = False
170+
elif isinstance(default, (list, dict, set)):
171+
data.update(default_factory=lambda: default.copy())
136172
required = False
173+
else:
174+
data.update(default=default)
175+
if required and (col.nullable or default is not None):
176+
required = False
177+
elif required and col.nullable:
178+
required = False
137179
data.update(required=required)
138180
if col.doc:
139181
data.update(description=col.doc)
@@ -142,4 +184,4 @@ def info(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
142184
return data
143185

144186

145-
JsonTypes = {list: t.List, dict: t.Dict}
187+
JsonTypes = {list: List, dict: Dict}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "aio-openapi"
3-
version = "2.7.2"
3+
version = "2.8.0"
44
description = "Minimal OpenAPI asynchronous server application"
55
documentation = "https://aio-openapi.readthedocs.io"
66
repository = "https://github.com/quantmind/aio-openapi"

tests/example/models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313

1414
@dataclass
1515
class TaskAdd(
16-
dataclass_from_table("_TaskAdd", DB.tasks, required=True, exclude=("id", "done"))
16+
dataclass_from_table(
17+
"_TaskAdd", DB.tasks, required=True, default=True, exclude=("id", "done")
18+
)
1719
):
1820
@classmethod
1921
def validate(cls, data, errors):

0 commit comments

Comments
 (0)