Skip to content

Commit cc3c2d3

Browse files
committed
feat(annots): support marking classes as singletons
1 parent 9fc7aee commit cc3c2d3

File tree

4 files changed

+158
-36
lines changed

4 files changed

+158
-36
lines changed

koerce/annots.py

Lines changed: 64 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from collections.abc import Mapping, Sequence
77
from types import FunctionType, MethodType
88
from typing import Any, ClassVar, Optional
9+
from weakref import WeakValueDictionary
910

1011
import cython
1112

@@ -541,7 +542,7 @@ def varkwargs(pattern=_any, typehint=EMPTY):
541542
return Parameter(kind=_VAR_KEYWORD, pattern=pattern, typehint=typehint)
542543

543544

544-
__create__ = cython.declare(object, type.__call__)
545+
__type_call__ = cython.declare(object, type.__call__)
545546
if cython.compiled:
546547
from cython.cimports.cpython.object import PyObject_GenericSetAttr as __setattr__
547548
else:
@@ -555,6 +556,7 @@ class AnnotableSpec:
555556
initable = cython.declare(cython.bint, visibility="readonly")
556557
hashable = cython.declare(cython.bint, visibility="readonly")
557558
immutable = cython.declare(cython.bint, visibility="readonly")
559+
singleton = cython.declare(cython.bint, visibility="readonly")
558560
signature = cython.declare(Signature, visibility="readonly")
559561
attributes = cython.declare(dict[str, Attribute], visibility="readonly")
560562
hasattribs = cython.declare(cython.bint, visibility="readonly")
@@ -564,44 +566,66 @@ def __init__(
564566
initable: bool,
565567
hashable: bool,
566568
immutable: bool,
569+
singleton: bool,
567570
signature: Signature,
568571
attributes: dict[str, Attribute],
569572
):
570573
self.initable = initable
571574
self.hashable = hashable
572575
self.immutable = immutable
576+
self.singleton = singleton
573577
self.signature = signature
574578
self.attributes = attributes
575579
self.hasattribs = bool(attributes)
576580

577581
@cython.cfunc
578582
@cython.inline
579583
def new(self, cls: type, args: tuple[Any, ...], kwargs: dict[str, Any]):
580-
ctx: dict[str, Any] = {}
581584
bound: dict[str, Any]
582-
param: Parameter
583-
584585
if not args and len(kwargs) == self.signature.length:
585586
bound = kwargs
586587
else:
587588
bound = self.signature.bind(args, kwargs)
588589

589-
if self.initable:
590-
# slow initialization calling __init__
591-
for name, param in self.signature.parameters.items():
592-
bound[name] = param.pattern.match(bound[name], ctx)
593-
return __create__(cls, **bound)
590+
if self.singleton or self.initable:
591+
return self.new_slow(cls, bound)
594592
else:
595-
# fast initialization directly setting the arguments
596-
this = cls.__new__(cls)
597-
for name, param in self.signature.parameters.items():
598-
__setattr__(this, name, param.pattern.match(bound[name], ctx))
599-
# TODO(kszucs): test order ot precomputes and attributes calculations
600-
if self.hashable:
601-
self.init_precomputes(this)
602-
if self.hasattribs:
603-
self.init_attributes(this)
604-
return this
593+
return self.new_fast(cls, bound)
594+
595+
@cython.cfunc
596+
@cython.inline
597+
def new_slow(self, cls: type, bound: dict[str, Any]):
598+
# slow initialization calling __init__
599+
ctx: dict[str, Any] = {}
600+
param: Parameter
601+
for name, param in self.signature.parameters.items():
602+
bound[name] = param.pattern.match(bound[name], ctx)
603+
604+
if self.singleton:
605+
key = (cls, *bound.items())
606+
try:
607+
return cls.__instances__[key]
608+
except KeyError:
609+
this = __type_call__(cls, **bound)
610+
cls.__instances__[key] = this
611+
return this
612+
613+
return __type_call__(cls, **bound)
614+
615+
@cython.cfunc
616+
@cython.inline
617+
def new_fast(self, cls: type, bound: dict[str, Any]):
618+
# fast initialization directly setting the arguments
619+
ctx: dict[str, Any] = {}
620+
param: Parameter
621+
this = cls.__new__(cls)
622+
for name, param in self.signature.parameters.items():
623+
__setattr__(this, name, param.pattern.match(bound[name], ctx))
624+
if self.hashable:
625+
self.init_precomputes(this)
626+
if self.hasattribs:
627+
self.init_attributes(this)
628+
return this
605629

606630
@cython.cfunc
607631
@cython.inline
@@ -627,8 +651,7 @@ def init_precomputes(self, this) -> cython.void:
627651
class AbstractMeta(type):
628652
"""Base metaclass for many of the ibis core classes.
629653
630-
Enforce the subclasses to define a `__slots__` attribute and provide a
631-
`__create__` classmethod to change the instantiation behavior of the class.
654+
Enforce the subclasses to define a `__slots__` attribute.
632655
633656
Support abstract methods without extending `abc.ABCMeta`. While it provides
634657
a reduced feature set compared to `abc.ABCMeta` (no way to register virtual
@@ -639,8 +662,8 @@ class AbstractMeta(type):
639662
__slots__ = ()
640663

641664
def __new__(metacls, clsname, bases, dct, **kwargs):
642-
# # enforce slot definitions
643-
# dct.setdefault("__slots__", ())
665+
# enforce slot definitions
666+
dct.setdefault("__slots__", ())
644667

645668
# construct the class object
646669
cls = super().__new__(metacls, clsname, bases, dct, **kwargs)
@@ -663,6 +686,10 @@ def __new__(metacls, clsname, bases, dct, **kwargs):
663686
return cls
664687

665688

689+
class Abstract(metaclass=AbstractMeta):
690+
"""Base class for many of the ibis core classes, see `AbstractMeta`."""
691+
692+
666693
class AnnotableMeta(AbstractMeta):
667694
def __new__(
668695
metacls,
@@ -672,6 +699,7 @@ def __new__(
672699
initable=None,
673700
hashable=None,
674701
immutable=None,
702+
singleton=False,
675703
allow_coercion=True,
676704
**kwargs,
677705
):
@@ -682,6 +710,7 @@ def __new__(
682710
is_initable: cython.bint
683711
is_hashable: cython.bint = hashable is True
684712
is_immutable: cython.bint = immutable is True
713+
is_singleton: cython.bint = singleton is True
685714
if initable is None:
686715
is_initable = "__init__" in dct or "__new__" in dct
687716
else:
@@ -713,6 +742,8 @@ def __new__(
713742
traits.append(Hashable)
714743
if immutable:
715744
traits.append(Immutable)
745+
if singleton:
746+
traits.append(Singleton)
716747

717748
# collect type annotations and convert them to patterns
718749
slots: list[str] = list(dct.pop("__slots__", []))
@@ -757,6 +788,7 @@ def __new__(
757788
spec = AnnotableSpec(
758789
initable=is_initable,
759790
hashable=is_hashable,
791+
singleton=is_singleton,
760792
immutable=is_immutable,
761793
signature=signature,
762794
attributes=attributes,
@@ -778,9 +810,14 @@ def __call__(cls, *args, **kwargs):
778810
return spec.new(cython.cast(type, cls), args, kwargs)
779811

780812

781-
class Immutable:
782-
__slots__ = ()
813+
class Singleton(Abstract):
814+
"""Cache instances of the class based on instantiation arguments."""
815+
816+
__instances__: Mapping[Any, Self] = WeakValueDictionary()
817+
__slots__ = ("__weakref__",)
783818

819+
820+
class Immutable(Abstract):
784821
def __copy__(self):
785822
return self
786823

@@ -794,7 +831,7 @@ def __setattr__(self, name: str, _: Any) -> None:
794831
)
795832

796833

797-
class Hashable:
834+
class Hashable(Abstract):
798835
__slots__ = ("__args__", "__precomputed_hash__")
799836

800837
def __hash__(self) -> int:
@@ -809,13 +846,11 @@ def __eq__(self, other) -> bool:
809846
)
810847

811848

812-
class Annotable(metaclass=AnnotableMeta, initable=False):
849+
class Annotable(Abstract, metaclass=AnnotableMeta, initable=False):
813850
__argnames__: ClassVar[tuple[str, ...]]
814851
__match_args__: ClassVar[tuple[str, ...]]
815852
__signature__: ClassVar[Signature]
816853

817-
__slots__ = ("__weakref__",)
818-
819854
def __init__(self, **kwargs):
820855
spec: AnnotableSpec = self.__spec__
821856
for name, value in kwargs.items():

koerce/tests/test_annots.py

Lines changed: 78 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
import pytest
2121
from typing_extensions import Self
2222

23-
from koerce._internal import (
23+
from koerce import (
2424
EMPTY,
25+
Abstract,
2526
AbstractMeta,
2627
Annotable,
2728
AnnotableMeta,
@@ -1269,7 +1270,7 @@ class AnnImm(Annotable, immutable=True):
12691270
lower = optional(is_int, default=0)
12701271
upper = optional(is_int, default=None)
12711272

1272-
assert AnnImm.__mro__ == (AnnImm, Immutable, Annotable, object)
1273+
assert AnnImm.__mro__ == (AnnImm, Immutable, Annotable, Abstract, object)
12731274

12741275
obj = AnnImm(3, lower=0, upper=4)
12751276
with pytest.raises(AttributeError):
@@ -1851,6 +1852,7 @@ def test_hashable():
18511852
Hashable,
18521853
Immutable,
18531854
Annotable,
1855+
Abstract,
18541856
object,
18551857
)
18561858

@@ -1878,10 +1880,6 @@ def test_hashable():
18781880
# hashable
18791881
assert {obj: 1}.get(obj) == 1
18801882

1881-
# weakrefable
1882-
ref = weakref.ref(obj)
1883-
assert ref() == obj
1884-
18851883
# serializable
18861884
assert pickle.loads(pickle.dumps(obj)) == obj
18871885

@@ -1954,7 +1952,7 @@ class Example(Annotable):
19541952

19551953

19561954
def test_abstract_meta():
1957-
class Foo(metaclass=AbstractMeta):
1955+
class Foo(Abstract):
19581956
@abstractmethod
19591957
def foo(self): ...
19601958

@@ -2154,3 +2152,76 @@ class User(Annotable):
21542152
assert User.__spec__.initable is False
21552153
assert User.__spec__.immutable is False
21562154
assert User.__spec__.hashable is False
2155+
2156+
2157+
def test_arg_and_hash_precomputed_before_attributes():
2158+
class Frozen(Annotable, immutable=True, hashable=True):
2159+
arg: int
2160+
2161+
@attribute
2162+
def a(self):
2163+
assert self.__args__ == (1,)
2164+
assert isinstance(self.__precomputed_hash__, int)
2165+
return "ok"
2166+
2167+
assert Frozen(1).a == "ok"
2168+
2169+
2170+
class OneAndOnly(Annotable, singleton=True):
2171+
__instances__ = weakref.WeakValueDictionary()
2172+
2173+
2174+
class DataType(Annotable, singleton=True):
2175+
__instances__ = weakref.WeakValueDictionary()
2176+
nullable: bool = True
2177+
2178+
2179+
def test_singleton_basics():
2180+
one = OneAndOnly()
2181+
only = OneAndOnly()
2182+
assert one is only
2183+
2184+
assert len(OneAndOnly.__instances__) == 1
2185+
key = (OneAndOnly,)
2186+
assert OneAndOnly.__instances__[key] is one
2187+
2188+
2189+
def test_singleton_lifetime() -> None:
2190+
one = OneAndOnly()
2191+
assert len(OneAndOnly.__instances__) == 1
2192+
2193+
del one
2194+
assert len(OneAndOnly.__instances__) == 0
2195+
2196+
2197+
def test_singleton_with_argument() -> None:
2198+
dt1 = DataType(nullable=True)
2199+
dt2 = DataType(nullable=False)
2200+
dt3 = DataType(nullable=True)
2201+
2202+
assert dt1 is dt3
2203+
assert dt1 is not dt2
2204+
assert len(DataType.__instances__) == 2
2205+
2206+
del dt3
2207+
assert len(DataType.__instances__) == 2
2208+
del dt1
2209+
assert len(DataType.__instances__) == 1
2210+
del dt2
2211+
assert len(DataType.__instances__) == 0
2212+
2213+
2214+
def test_singleton_looked_after_validation() -> None:
2215+
class Single(Annotable, singleton=True):
2216+
value: As[int]
2217+
2218+
# arguments looked up after validation
2219+
obj1 = Single("1")
2220+
obj2 = Single(2)
2221+
assert Single("1") is obj1
2222+
assert Single(1) is obj1
2223+
assert Single(1.0) is obj1
2224+
assert Single(2) is obj2
2225+
assert Single("2") is obj2
2226+
assert obj2 is not obj1
2227+
assert Single("3") is Single(3.0)

koerce/tests/test_y.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ def test_msgspec(benchmark):
204204

205205

206206
def test_annotated(benchmark):
207+
assert KUser.__spec__.initable is False
208+
assert KUser.__spec__.singleton is False
209+
207210
r2 = benchmark.pedantic(
208211
KUser,
209212
args=(),

koerce/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,16 @@ def rewind(self):
251251
def checkpoint(self):
252252
"""Create a checkpoint of the current iterator state."""
253253
self._iterator, self._checkpoint = itertools.tee(self._iterator)
254+
255+
256+
# def format_typehint(typ: Any) -> str:
257+
# if isinstance(typ, type):
258+
# return typ.__name__
259+
# elif isinstance(typ, TypeVar):
260+
# if typ.__bound__ is None:
261+
# return str(typ)
262+
# else:
263+
# return format_typehint(typ.__bound__)
264+
# else:
265+
# # remove the module name from the typehint, including generics
266+
# return re.sub(r"(\w+\.)+", "", str(typ))

0 commit comments

Comments
 (0)