1
- import typing as t
2
1
from dataclasses import make_dataclass
3
2
from datetime import date , datetime
4
3
from decimal import Decimal
5
4
from functools import partial
5
+ from typing import Dict , List , Optional , Sequence , Set , Tuple , Union , cast
6
6
7
7
import sqlalchemy as sa
8
8
from sqlalchemy_utils import UUIDType
@@ -16,35 +16,54 @@ def dataclass_from_table(
16
16
name : str ,
17
17
table : sa .Table ,
18
18
* ,
19
- exclude : t .Optional [t .Sequence [str ]] = None ,
20
- include : t .Optional [t .Sequence [str ]] = None ,
21
- required : bool = False ,
22
- ops : t .Optional [t .Dict [str , t .Sequence [str ]]] = None ,
19
+ exclude : Optional [Sequence [str ]] = None ,
20
+ include : Optional [Sequence [str ]] = None ,
21
+ default : Union [bool , Sequence [str ]] = False ,
22
+ required : Union [bool , Sequence [str ]] = False ,
23
+ ops : Optional [Dict [str , Sequence [str ]]] = None ,
23
24
) -> type :
24
25
"""Create a dataclass from an :class:`sqlalchemy.schema.Table`
25
26
26
27
:param name: dataclass name
27
28
:param table: sqlalchemy table
28
29
:param exclude: fields to exclude from the dataclass
29
30
:param include: fields to include in the dataclass
30
- :param required: set all non nullable columns as required fields in the dataclass
31
+ :param default: use columns defaults in the dataclass
32
+ :param required: set non nullable columns without a default as
33
+ required fields in the dataclass
31
34
:param ops: additional operation for fields
32
35
"""
33
36
columns = []
34
37
include = set (include or table .columns .keys ()) - set (exclude or ())
35
- column_ops = t .cast (t .Dict [str , t .Sequence [str ]], ops or {})
38
+ defaults = column_info (include , default )
39
+ requireds = column_info (include , required )
40
+ column_ops = cast (Dict [str , Sequence [str ]], ops or {})
36
41
for col in table .columns :
37
42
if col .name not in include :
38
43
continue
39
44
ctype = type (col .type )
40
45
converter = CONVERTERS .get (ctype )
41
46
if not converter : # pragma: no cover
42
47
raise NotImplementedError (f"Cannot convert column { col .name } : { ctype } " )
43
- field = (col .name , * converter (col , required , column_ops .get (col .name , ())))
48
+ required = col .name in requireds
49
+ use_default = col .name in defaults
50
+ field = (
51
+ col .name ,
52
+ * converter (col , required , use_default , column_ops .get (col .name , ())),
53
+ )
44
54
columns .append (field )
45
55
return make_dataclass (name , columns )
46
56
47
57
58
+ def column_info (columns : Set [str ], value : Union [bool , Sequence [str ]]) -> Set [str ]:
59
+ if value is False :
60
+ return set ()
61
+ elif value is True :
62
+ return columns .copy ()
63
+ else :
64
+ return set (value if value is not None else columns )
65
+
66
+
48
67
def converter (* types ):
49
68
def _ (f ):
50
69
for type_ in types :
@@ -55,85 +74,108 @@ def _(f):
55
74
56
75
57
76
@converter (sa .Boolean )
58
- def bl (col : sa .Column , required : bool , ops : t . Sequence [str ]) -> t . Tuple :
77
+ def bl (col : sa .Column , required : bool , use_default : bool , ops : Sequence [str ]) -> Tuple :
59
78
data_field = col .info .get ("data_field" , fields .bool_field )
60
- return (bool , data_field (** info (col , required , ops )))
79
+ return (bool , data_field (** info (col , required , use_default , ops )))
61
80
62
81
63
82
@converter (sa .Integer )
64
- def integer (col : sa .Column , required : bool , ops : t .Sequence [str ]) -> t .Tuple :
83
+ def integer (
84
+ col : sa .Column , required : bool , use_default : bool , ops : Sequence [str ]
85
+ ) -> Tuple :
65
86
data_field = col .info .get ("data_field" , fields .number_field )
66
- return (int , data_field (precision = 0 , ** info (col , required , ops )))
87
+ return (int , data_field (precision = 0 , ** info (col , required , use_default , ops )))
67
88
68
89
69
90
@converter (sa .Numeric )
70
- def number (col : sa .Column , required : bool , ops : t .Sequence [str ]) -> t .Tuple :
91
+ def number (
92
+ col : sa .Column , required : bool , use_default : bool , ops : Sequence [str ]
93
+ ) -> Tuple :
71
94
data_field = col .info .get ("data_field" , fields .decimal_field )
72
- return (Decimal , data_field (precision = col .type .scale , ** info (col , required , ops )))
95
+ return (
96
+ Decimal ,
97
+ data_field (precision = col .type .scale , ** info (col , required , use_default , ops )),
98
+ )
73
99
74
100
75
101
@converter (sa .String , sa .Text , sa .CHAR , sa .VARCHAR )
76
- def string (col : sa .Column , required : bool , ops : t .Sequence [str ]) -> t .Tuple :
102
+ def string (
103
+ col : sa .Column , required : bool , use_default : bool , ops : Sequence [str ]
104
+ ) -> Tuple :
77
105
data_field = col .info .get ("data_field" , fields .str_field )
78
106
return (
79
107
str ,
80
- data_field (max_length = col .type .length or 0 , ** info (col , required , ops )),
108
+ data_field (
109
+ max_length = col .type .length or 0 , ** info (col , required , use_default , ops )
110
+ ),
81
111
)
82
112
83
113
84
114
@converter (sa .DateTime )
85
- def dt_ti (col : sa .Column , required : bool , ops : t .Sequence [str ]) -> t .Tuple :
115
+ def dt_ti (
116
+ col : sa .Column , required : bool , use_default : bool , ops : Sequence [str ]
117
+ ) -> Tuple :
86
118
data_field = col .info .get ("data_field" , fields .date_time_field )
87
119
return (
88
120
datetime ,
89
- data_field (timezone = col .type .timezone , ** info (col , required , ops )),
121
+ data_field (timezone = col .type .timezone , ** info (col , required , use_default , ops )),
90
122
)
91
123
92
124
93
125
@converter (sa .Date )
94
- def dt (col : sa .Column , required : bool , ops : t . Sequence [str ]) -> t . Tuple :
126
+ def dt (col : sa .Column , required : bool , use_default : bool , ops : Sequence [str ]) -> Tuple :
95
127
data_field = col .info .get ("data_field" , fields .date_field )
96
- return (date , data_field (** info (col , required , ops )))
128
+ return (date , data_field (** info (col , required , use_default , ops )))
97
129
98
130
99
131
@converter (sa .Enum )
100
- def en (col : sa .Column , required : bool , ops : t . Sequence [str ]) -> t . Tuple :
132
+ def en (col : sa .Column , required : bool , use_default : bool , ops : Sequence [str ]) -> Tuple :
101
133
data_field = col .info .get ("data_field" , fields .enum_field )
102
134
return (
103
135
col .type .enum_class ,
104
- data_field (col .type .enum_class , ** info (col , required , ops )),
136
+ data_field (col .type .enum_class , ** info (col , required , use_default , ops )),
105
137
)
106
138
107
139
108
140
@converter (sa .JSON )
109
- def js (col : sa .Column , required : bool , ops : t . Sequence [str ]) -> t . Tuple :
141
+ def js (col : sa .Column , required : bool , use_default : bool , ops : Sequence [str ]) -> Tuple :
110
142
data_field = col .info .get ("data_field" , fields .json_field )
111
143
val = None
112
144
if col .default :
113
145
arg = col .default .arg
114
146
val = arg () if col .default .is_callable else arg
115
- return (JsonTypes .get (type (val ), t .Dict ), data_field (** info (col , required , ops )))
147
+ return (
148
+ JsonTypes .get (type (val ), Dict ),
149
+ data_field (** info (col , required , use_default , ops )),
150
+ )
116
151
117
152
118
153
@converter (UUIDType )
119
- def uuid (col : sa .Column , required : bool , ops : t .Sequence [str ]) -> t .Tuple :
154
+ def uuid (
155
+ col : sa .Column , required : bool , use_default : bool , ops : Sequence [str ]
156
+ ) -> Tuple :
120
157
data_field = col .info .get ("data_field" , fields .uuid_field )
121
- return (str , data_field (** info (col , required , ops )))
158
+ return (str , data_field (** info (col , required , use_default , ops )))
122
159
123
160
124
- def info (col : sa .Column , required : bool , ops : t .Sequence [str ]) -> t .Tuple :
161
+ def info (
162
+ col : sa .Column , required : bool , use_default : bool , ops : Sequence [str ]
163
+ ) -> Tuple :
125
164
data = dict (ops = ops )
126
- default = col .default .arg if col .default is not None else None
127
- if callable (default ):
128
- data .update (default_factory = partial (default , None ))
129
- required = False
130
- elif isinstance (default , (list , dict , set )):
131
- data .update (default_factory = lambda : default .copy ())
132
- required = False
133
- else :
134
- data .update (default = default )
135
- if required and (col .nullable or default is not None ):
165
+ if use_default :
166
+ default = col .default .arg if col .default is not None else None
167
+ if callable (default ):
168
+ data .update (default_factory = partial (default , None ))
169
+ required = False
170
+ elif isinstance (default , (list , dict , set )):
171
+ data .update (default_factory = lambda : default .copy ())
136
172
required = False
173
+ else :
174
+ data .update (default = default )
175
+ if required and (col .nullable or default is not None ):
176
+ required = False
177
+ elif required and col .nullable :
178
+ required = False
137
179
data .update (required = required )
138
180
if col .doc :
139
181
data .update (description = col .doc )
@@ -142,4 +184,4 @@ def info(col: sa.Column, required: bool, ops: t.Sequence[str]) -> t.Tuple:
142
184
return data
143
185
144
186
145
- JsonTypes = {list : t . List , dict : t . Dict }
187
+ JsonTypes = {list : List , dict : Dict }
0 commit comments