1
1
from __future__ import annotations
2
2
3
- import dataclasses
4
3
from collections .abc import Callable , Iterable , Iterator , Sequence
5
4
from functools import partial
6
5
from inspect import isclass
10
9
Literal ,
11
10
TypeVar ,
12
11
Union ,
13
- get_origin ,
14
- get_type_hints ,
15
12
)
16
13
17
14
import numpy as np
21
18
GetCoreSchemaHandler ,
22
19
TypeAdapter ,
23
20
)
24
- from pydantic .dataclasses import rebuild_dataclass
25
21
from pydantic .fields import FieldInfo
22
+ from pydantic_core .core_schema import tagged_union_schema
26
23
27
24
__all__ = [
28
25
"if_instance_do" ,
@@ -116,7 +113,9 @@ def calculate(self) -> int:
116
113
"""
117
114
tagged_union = _TaggedUnion (cls , discriminator )
118
115
_tagged_unions [cls ] = tagged_union
119
- cls .__init_subclass__ = classmethod (partial (__init_subclass__ , discriminator ))
116
+ cls .__init_subclass__ = classmethod (
117
+ partial (__init_subclass__ , tagged_union , discriminator )
118
+ )
120
119
cls .__get_pydantic_core_schema__ = classmethod (
121
120
partial (__get_pydantic_core_schema__ , tagged_union = tagged_union )
122
121
)
@@ -126,35 +125,13 @@ def calculate(self) -> int:
126
125
T = TypeVar ("T" , type , Callable )
127
126
128
127
129
- def deserialize_as (cls , obj ):
130
- return _tagged_unions [cls ].type_adapter .validate_python (obj )
131
-
132
-
133
- def uses_tagged_union (cls_or_func : T ) -> T :
134
- """
135
- Decorator that processes the type hints of a class or function to detect and
136
- register any tagged unions. If a tagged union is detected in the type hints,
137
- it registers the class or function as a referrer to that tagged union.
138
- Args:
139
- cls_or_func (T): The class or function to be processed for tagged unions.
140
- Returns:
141
- T: The original class or function, unmodified.
142
- """
143
- for k , v in get_type_hints (cls_or_func ).items ():
144
- tagged_union = _tagged_unions .get (get_origin (v ) or v , None )
145
- if tagged_union :
146
- tagged_union .add_referrer (cls_or_func , k )
147
- return cls_or_func
148
-
149
-
150
128
class _TaggedUnion :
151
129
def __init__ (self , base_class : type , discriminator : str ):
152
130
self ._base_class = base_class
153
131
# The members of the tagged union, i.e. subclasses of the baseclasses
154
132
self ._members : list [type ] = []
155
133
# Classes and their field names that refer to this tagged union
156
- self ._referrers : dict [type | Callable , set [str ]] = {}
157
- self .type_adapter : TypeAdapter = TypeAdapter (None )
134
+ self .type_adapter : TypeAdapter | None = None
158
135
self ._discriminator = discriminator
159
136
160
137
def _make_union (self ):
@@ -173,55 +150,36 @@ def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any):
173
150
def add_member (self , cls : type ):
174
151
if cls in self ._members :
175
152
# A side effect of hooking to __get_pydantic_core_schema__ is that it is
176
- # called muliple times for the same member, do no process if it wouldn't
153
+ # called multiple times for the same member, do no process if it wouldn't
177
154
# change the member list
178
155
return
179
156
180
157
self ._members .append (cls )
181
158
union = self ._make_union ()
182
159
if union :
183
- # There are more than 1 subclasses in the union, so set all the referrers
184
- # to use this union
185
- for referrer , fields in self ._referrers .items ():
186
- if isclass (referrer ):
187
- for field in dataclasses .fields (referrer ):
188
- if field .name in fields :
189
- field .type = union
190
- self ._set_discriminator (referrer , field .name , field .default )
191
- rebuild_dataclass (referrer , force = True )
192
- # Make a type adapter for use in deserialization
193
160
self .type_adapter = TypeAdapter (union )
194
161
195
- def add_referrer (self , cls : type | Callable , attr_name : str ):
196
- self ._referrers .setdefault (cls , set ()).add (attr_name )
197
- union = self ._make_union ()
198
- if union :
199
- # There are more than 1 subclasses in the union, so set the referrer
200
- # (which is currently being constructed) to use it
201
- # note that we use annotations as the class has not been turned into
202
- # a dataclass yet
203
- cls .__annotations__ [attr_name ] = union
204
- self ._set_discriminator (cls , attr_name , getattr (cls , attr_name , None ))
205
-
206
162
207
163
_tagged_unions : dict [type , _TaggedUnion ] = {}
208
164
209
165
210
- def __init_subclass__ (discriminator : str , cls : type ):
166
+ def __init_subclass__ (tagged_union : _TaggedUnion , discriminator : str , cls : type ):
211
167
# Add a discriminator field to the class so it can
212
- # be identified when deserailizing , and make sure it is last in the list
168
+ # be identified when deserializing , and make sure it is last in the list
213
169
cls .__annotations__ = {
214
170
** cls .__annotations__ ,
215
171
discriminator : Literal [cls .__name__ ], # type: ignore
216
172
}
217
173
cls .type = Field (cls .__name__ , repr = False ) # type: ignore
218
- # Replace any bare annotation with a discriminated union of subclasses
219
- # and register this class as one that refers to that union so it can be updated
220
- for k , v in get_type_hints (cls ).items ():
221
- # This works for Expression[T] or Expression
222
- tagged_union = _tagged_unions .get (get_origin (v ) or v , None )
223
- if tagged_union :
224
- tagged_union .add_referrer (cls , k )
174
+
175
+ def __get_pydantic_core_schema__ (
176
+ cls , source_type : Any , handler : GetCoreSchemaHandler
177
+ ):
178
+ handler .generate_schema (cls )
179
+ return handler (cls )
180
+
181
+ cls .__get_pydantic_core_schema__ = classmethod (__get_pydantic_core_schema__ )
182
+ tagged_union .add_member (cls )
225
183
226
184
227
185
def __get_pydantic_core_schema__ (
@@ -230,8 +188,11 @@ def __get_pydantic_core_schema__(
230
188
# Rebuild any dataclass (including this one) that references this union
231
189
# Note that this has to be done after the creation of the dataclass so that
232
190
# previously created classes can refer to this newly created class
233
- tagged_union .add_member (cls )
234
- return handler (source_type )
191
+ # return handler(tagged_union._make_union())
192
+ return tagged_union_schema (
193
+ {member .__name__ : handler (member ) for member in tagged_union ._members },
194
+ tagged_union ._discriminator ,
195
+ )
235
196
236
197
237
198
def if_instance_do (x : Any , cls : type , func : Callable ):
0 commit comments