|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -import dataclasses |
4 | 3 | from collections.abc import Callable, Iterable, Iterator, Sequence
|
5 |
| -from functools import partial |
6 |
| -from inspect import isclass |
| 4 | +from functools import lru_cache |
7 | 5 | from typing import (
|
8 | 6 | Any,
|
9 | 7 | Generic,
|
10 | 8 | Literal,
|
11 | 9 | TypeVar,
|
12 |
| - Union, |
13 |
| - get_origin, |
14 |
| - get_type_hints, |
15 | 10 | )
|
16 | 11 |
|
17 | 12 | import numpy as np
|
18 | 13 | from pydantic import (
|
19 | 14 | ConfigDict,
|
20 | 15 | Field,
|
21 | 16 | GetCoreSchemaHandler,
|
22 |
| - TypeAdapter, |
23 | 17 | )
|
24 | 18 | from pydantic.dataclasses import rebuild_dataclass
|
25 |
| -from pydantic.fields import FieldInfo |
| 19 | +from pydantic_core.core_schema import tagged_union_schema |
26 | 20 |
|
27 | 21 | __all__ = [
|
28 | 22 | "if_instance_do",
|
|
43 | 37 |
|
44 | 38 |
|
45 | 39 | def discriminated_union_of_subclasses(
|
46 |
| - cls, |
| 40 | + cls: type, |
47 | 41 | discriminator: str = "type",
|
48 |
| -): |
| 42 | +) -> type: |
49 | 43 | """Add all subclasses of super_cls to a discriminated union.
|
50 | 44 |
|
51 | 45 | For all subclasses of super_cls, add a discriminator field to identify
|
52 |
| - the type. Raw JSON should look like {"type": <type name>, params for |
| 46 | + the type. Raw JSON should look like {<discriminator>: <type name>, params for |
53 | 47 | <type name>...}.
|
54 | 48 |
|
55 | 49 | Example::
|
@@ -107,131 +101,69 @@ def calculate(self) -> int:
|
107 | 101 | super_cls: The superclass of the union, Expression in the above example
|
108 | 102 | discriminator: The discriminator that will be inserted into the
|
109 | 103 | serialized documents for type determination. Defaults to "type".
|
110 |
| - config: A pydantic config class to be inserted into all |
111 |
| - subclasses. Defaults to None. |
112 | 104 |
|
113 | 105 | Returns:
|
114 |
| - Type | Callable[[Type], Type]: A decorator that adds the necessary |
| 106 | + Type: A decorator that adds the necessary |
115 | 107 | functionality to a class.
|
116 | 108 | """
|
117 | 109 | tagged_union = _TaggedUnion(cls, discriminator)
|
118 |
| - _tagged_unions[cls] = tagged_union |
119 |
| - cls.__init_subclass__ = classmethod(partial(__init_subclass__, discriminator)) |
120 |
| - cls.__get_pydantic_core_schema__ = classmethod( |
121 |
| - partial(__get_pydantic_core_schema__, tagged_union=tagged_union) |
122 |
| - ) |
123 |
| - return cls |
124 | 110 |
|
| 111 | + def add_subclass_to_union(subclass): |
| 112 | + # Add a discriminator field to a subclass so it can |
| 113 | + # be identified when deserializing |
| 114 | + subclass.__annotations__ = { |
| 115 | + **subclass.__annotations__, |
| 116 | + discriminator: Literal[subclass.__name__], # type: ignore |
| 117 | + } |
| 118 | + setattr(subclass, discriminator, Field(subclass.__name__, repr=False)) # type: ignore |
125 | 119 |
|
126 |
| -T = TypeVar("T", type, Callable) |
| 120 | + def default_handler(subclass, source_type: Any, handler: GetCoreSchemaHandler): |
| 121 | + tagged_union.add_member(subclass) |
| 122 | + return handler(subclass) |
127 | 123 |
|
| 124 | + subclass.__get_pydantic_core_schema__ = classmethod(default_handler) |
128 | 125 |
|
129 |
| -def deserialize_as(cls, obj): |
130 |
| - return _tagged_unions[cls].type_adapter.validate_python(obj) |
| 126 | + def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler): |
| 127 | + # Rebuild any dataclass (including this one) that references this union |
| 128 | + # Note that this has to be done after the creation of the dataclass so that |
| 129 | + # previously created classes can refer to this newly created class |
| 130 | + return tagged_union.schema(handler) |
131 | 131 |
|
| 132 | + cls.__init_subclass__ = classmethod(add_subclass_to_union) |
| 133 | + cls.__get_pydantic_core_schema__ = classmethod(get_schema_of_union) |
| 134 | + return cls |
132 | 135 |
|
133 |
| -def uses_tagged_union(cls_or_func: T) -> T: |
134 |
| - """ |
135 |
| - Decorator that processes the type hints of a class or function to detect and |
136 |
| - register any tagged unions. If a tagged union is detected in the type hints, |
137 |
| - it registers the class or function as a referrer to that tagged union. |
138 |
| - Args: |
139 |
| - cls_or_func (T): The class or function to be processed for tagged unions. |
140 |
| - Returns: |
141 |
| - T: The original class or function, unmodified. |
142 |
| - """ |
143 |
| - for k, v in get_type_hints(cls_or_func).items(): |
144 |
| - tagged_union = _tagged_unions.get(get_origin(v) or v, None) |
145 |
| - if tagged_union: |
146 |
| - tagged_union.add_referrer(cls_or_func, k) |
147 |
| - return cls_or_func |
| 136 | + |
| 137 | +T = TypeVar("T", type, Callable) |
148 | 138 |
|
149 | 139 |
|
150 | 140 | class _TaggedUnion:
|
151 | 141 | def __init__(self, base_class: type, discriminator: str):
|
152 | 142 | self._base_class = base_class
|
153 |
| - # The members of the tagged union, i.e. subclasses of the baseclasses |
154 |
| - self._members: list[type] = [] |
155 | 143 | # Classes and their field names that refer to this tagged union
|
156 |
| - self._referrers: dict[type | Callable, set[str]] = {} |
157 |
| - self.type_adapter: TypeAdapter = TypeAdapter(None) |
158 | 144 | self._discriminator = discriminator
|
159 |
| - |
160 |
| - def _make_union(self): |
161 |
| - if len(self._members) > 0: |
162 |
| - return Union[tuple(self._members)] # type: ignore # noqa |
163 |
| - |
164 |
| - def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any): |
165 |
| - # Set the field to use the `type` discriminator on deserialize |
166 |
| - # https://docs.pydantic.dev/2.8/concepts/unions/#discriminated-unions-with-str-discriminators |
167 |
| - if isclass(cls): |
168 |
| - assert isinstance( |
169 |
| - field, FieldInfo |
170 |
| - ), f"Expected {cls.__name__}.{field_name} to be a Pydantic field, not {field!r}" # noqa: E501 |
171 |
| - field.discriminator = self._discriminator |
| 145 | + # The members of the tagged union, i.e. subclasses of the baseclass |
| 146 | + self._members: list[type] = [] |
172 | 147 |
|
173 | 148 | def add_member(self, cls: type):
|
174 | 149 | if cls in self._members:
|
175 |
| - # A side effect of hooking to __get_pydantic_core_schema__ is that it is |
176 |
| - # called muliple times for the same member, do no process if it wouldn't |
177 |
| - # change the member list |
178 | 150 | return
|
179 |
| - |
180 | 151 | self._members.append(cls)
|
181 |
| - union = self._make_union() |
182 |
| - if union: |
183 |
| - # There are more than 1 subclasses in the union, so set all the referrers |
184 |
| - # to use this union |
185 |
| - for referrer, fields in self._referrers.items(): |
186 |
| - if isclass(referrer): |
187 |
| - for field in dataclasses.fields(referrer): |
188 |
| - if field.name in fields: |
189 |
| - field.type = union |
190 |
| - self._set_discriminator(referrer, field.name, field.default) |
191 |
| - rebuild_dataclass(referrer, force=True) |
192 |
| - # Make a type adapter for use in deserialization |
193 |
| - self.type_adapter = TypeAdapter(union) |
194 |
| - |
195 |
| - def add_referrer(self, cls: type | Callable, attr_name: str): |
196 |
| - self._referrers.setdefault(cls, set()).add(attr_name) |
197 |
| - union = self._make_union() |
198 |
| - if union: |
199 |
| - # There are more than 1 subclasses in the union, so set the referrer |
200 |
| - # (which is currently being constructed) to use it |
201 |
| - # note that we use annotations as the class has not been turned into |
202 |
| - # a dataclass yet |
203 |
| - cls.__annotations__[attr_name] = union |
204 |
| - self._set_discriminator(cls, attr_name, getattr(cls, attr_name, None)) |
205 |
| - |
206 |
| - |
207 |
| -_tagged_unions: dict[type, _TaggedUnion] = {} |
208 |
| - |
209 |
| - |
210 |
| -def __init_subclass__(discriminator: str, cls: type): |
211 |
| - # Add a discriminator field to the class so it can |
212 |
| - # be identified when deserailizing, and make sure it is last in the list |
213 |
| - cls.__annotations__ = { |
214 |
| - **cls.__annotations__, |
215 |
| - discriminator: Literal[cls.__name__], # type: ignore |
216 |
| - } |
217 |
| - cls.type = Field(cls.__name__, repr=False) # type: ignore |
218 |
| - # Replace any bare annotation with a discriminated union of subclasses |
219 |
| - # and register this class as one that refers to that union so it can be updated |
220 |
| - for k, v in get_type_hints(cls).items(): |
221 |
| - # This works for Expression[T] or Expression |
222 |
| - tagged_union = _tagged_unions.get(get_origin(v) or v, None) |
223 |
| - if tagged_union: |
224 |
| - tagged_union.add_referrer(cls, k) |
225 |
| - |
226 |
| - |
227 |
| -def __get_pydantic_core_schema__( |
228 |
| - cls, source_type: Any, handler: GetCoreSchemaHandler, tagged_union: _TaggedUnion |
229 |
| -): |
230 |
| - # Rebuild any dataclass (including this one) that references this union |
231 |
| - # Note that this has to be done after the creation of the dataclass so that |
232 |
| - # previously created classes can refer to this newly created class |
233 |
| - tagged_union.add_member(cls) |
234 |
| - return handler(source_type) |
| 152 | + for member in self._members: |
| 153 | + if member != cls: |
| 154 | + rebuild_dataclass(member, force=True) |
| 155 | + |
| 156 | + def schema(self, handler): |
| 157 | + return tagged_union_schema( |
| 158 | + make_schema(tuple(self._members), handler), |
| 159 | + discriminator=self._discriminator, |
| 160 | + ref=self._base_class.__name__, |
| 161 | + ) |
| 162 | + |
| 163 | + |
| 164 | +@lru_cache(1) |
| 165 | +def make_schema(members: tuple[type, ...], handler): |
| 166 | + return {member.__name__: handler(member) for member in members} |
235 | 167 |
|
236 | 168 |
|
237 | 169 | def if_instance_do(x: Any, cls: type, func: Callable):
|
|
0 commit comments