Skip to content

Commit

Permalink
Merge pull request #140 from bluesky/remove-uses-tagged-union
Browse files Browse the repository at this point in the history
Remove the need for uses_tagged_union deco
  • Loading branch information
coretl authored Aug 30, 2024
2 parents 0f7329f + 9e6e436 commit f5251e8
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 54 deletions.
38 changes: 3 additions & 35 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,7 @@
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import lru_cache
from inspect import isclass
from typing import (
Any,
Generic,
Literal,
TypeVar,
get_origin,
get_type_hints,
)
from typing import Any, Generic, Literal, TypeVar

import numpy as np
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler
Expand Down Expand Up @@ -140,24 +133,6 @@ def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
return super_cls


def uses_tagged_union(cls_or_func: T) -> T:
"""
T = TypeVar("T", type, Callable)
Decorator that processes the type hints of a class or function to detect and
register any tagged unions. If a tagged union is detected in the type hints,
it registers the class or function as a referrer to that tagged union.
Args:
cls_or_func (T): The class or function to be processed for tagged unions.
Returns:
T: The original class or function, unmodified.
"""
for v in get_type_hints(cls_or_func).values():
tagged_union = _tagged_unions.get(get_origin(v) or v, None)
if tagged_union:
tagged_union.add_reference(cls_or_func)
return cls_or_func


_tagged_unions: dict[type, _TaggedUnion] = {}


Expand All @@ -168,7 +143,6 @@ def __init__(self, base_class: type, discriminator: str):
self._discriminator = discriminator
# The members of the tagged union, i.e. subclasses of the baseclass
self._subclasses: list[type] = []
self._references: set[type | Callable] = set()

def add_member(self, cls: type):
if cls in self._subclasses:
Expand All @@ -177,14 +151,8 @@ def add_member(self, cls: type):
for member in self._subclasses:
if member is not cls:
_TaggedUnion._rebuild(member)
for ref in self._references:
_TaggedUnion._rebuild(ref)

def add_reference(self, cls_or_func: type | Callable):
self._references.add(cls_or_func)

@staticmethod
# https://github.com/bluesky/scanspec/issues/133
def _rebuild(cls_or_func: type | Callable):
if isclass(cls_or_func):
if is_pydantic_dataclass(cls_or_func):
Expand All @@ -194,14 +162,14 @@ def _rebuild(cls_or_func: type | Callable):

def schema(self, handler: GetCoreSchemaHandler) -> CoreSchema:
return tagged_union_schema(
make_schema(tuple(self._subclasses), handler),
_make_schema(tuple(self._subclasses), handler),
discriminator=self._discriminator,
ref=self._base_class.__name__,
)


@lru_cache(1)
def make_schema(members: tuple[type, ...], handler):
def _make_schema(members: tuple[type, ...], handler):
return {member.__name__: handler(member) for member in members}


Expand Down
4 changes: 1 addition & 3 deletions src/scanspec/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pydantic import Field
from pydantic.dataclasses import dataclass

from scanspec.core import AxesPoints, Frames, Path, uses_tagged_union
from scanspec.core import AxesPoints, Frames, Path

from .specs import Line, Spec

Expand All @@ -25,7 +25,6 @@
Points = str | list[float]


@uses_tagged_union
@dataclass
class ValidResponse:
"""Response model for spec validation."""
Expand All @@ -42,7 +41,6 @@ class PointsFormat(str, Enum):
BASE64_ENCODED = "BASE64_ENCODED"


@uses_tagged_union
@dataclass
class PointsRequest:
"""A request for generated scan points."""
Expand Down
16 changes: 0 additions & 16 deletions tests/test_basemodel.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import pytest
from pydantic import BaseModel, TypeAdapter
from pydantic.dataclasses import dataclass

from scanspec.core import StrictConfig, uses_tagged_union
from scanspec.specs import Line, Spec


@uses_tagged_union
class Foo(BaseModel):
spec: Spec

Expand Down Expand Up @@ -41,16 +38,3 @@ def test_type_adapter(model: Foo):
as_json = model.model_dump_json()
deserialized = type_adapter.validate_json(as_json)
assert deserialized == model


def test_schema_updates_with_new_values():
old_schema = TypeAdapter(Foo).json_schema()

@dataclass(config=StrictConfig)
class Splat(Spec[str]): # NOSONAR
def axes(self) -> list[str]:
return ["*"]

new_schema = TypeAdapter(Foo).json_schema()

assert new_schema != old_schema

0 comments on commit f5251e8

Please sign in to comment.