Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions dace/codegen/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,8 @@ def register_array_dispatcher(self, storage_type: dtypes.StorageType, func: targ
self.register_array_dispatcher(stype, func)
return

if not isinstance(storage_type, dtypes.StorageType): raise TypeError
if not (isinstance(storage_type, dtypes.StorageType) or isinstance(storage_type, dtypes.CheckableEnumMeta._Proxy)):
raise TypeError
if not isinstance(func, target.TargetCodeGenerator): raise TypeError
self._array_dispatchers[storage_type] = func

Expand Down Expand Up @@ -345,8 +346,8 @@ def register_copy_dispatcher(self,
:see: TargetCodeGenerator
"""

if not isinstance(src_storage, dtypes.StorageType): raise TypeError
if not isinstance(dst_storage, dtypes.StorageType): raise TypeError
if not (isinstance(src_storage, dtypes.StorageType) or isinstance(src_storage, dtypes.CheckableEnumMeta._Proxy)): raise TypeError
if not (isinstance(dst_storage, dtypes.StorageType) or isinstance(dst_storage, dtypes.CheckableEnumMeta._Proxy)): raise TypeError
if (dst_schedule is not None and not isinstance(dst_schedule, dtypes.ScheduleType)):
raise TypeError
if not isinstance(func, target.TargetCodeGenerator): raise TypeError
Expand Down
98 changes: 84 additions & 14 deletions dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,68 @@
from typing import Any, Dict, TYPE_CHECKING
from dace.config import Config
from dace.registry import extensible_enum, undefined_safe_enum
from enum import Enum, EnumMeta, auto
from dataclasses import dataclass

class CheckableEnumMeta(EnumMeta):
class _Proxy:
__slots__ = ("_member", "name", "_obj", "_args", "_kwargs")

def __init__(self, member, obj, args, kwargs):
self._member = member
self.name = obj.__class__.__name__
self._obj = obj
self._args = args
self._kwargs = kwargs

def __getattr__(self, name):
# Delegate attribute lookups to the real object
return getattr(self._obj, name)

def __deepcopy__(self, memo):
# Avoid infinite recursion by copying only what is necessary
from copy import deepcopy
# Optionally: if _obj is immutable or doesn't need deep copy, just reference it
obj_copy = deepcopy(self._obj, memo)
args_copy = deepcopy(self._args, memo)
kwargs_copy = deepcopy(self._kwargs, memo)
return type(self)(self._member, obj_copy, args_copy, kwargs_copy)

def __eq__(self, other):
# proxy == enum‐member?
if other is self._member:
return True
# proxy == proxy?
if isinstance(other, CheckableEnumMeta._Proxy):
return self._member is other._member and self._args == other._args and self._kwargs == other._kwargs
return False

def __hash__(self):
return hash((id(self._member), self._args, frozenset(self._kwargs.items())))

def __repr__(self):
return f"{self._member.name}{self._args}{self._kwargs}"

def __new__(mcs, clsname, bases, classdict):
cls = super().__new__(mcs, clsname, bases, classdict)

# 1) Make Enum members callable, returning our Proxy
def __call__(self, *args, **kwargs):
real_obj = self.value(*args, **kwargs)
return CheckableEnumMeta._Proxy(self, real_obj, args, kwargs)

cls.__call__ = __call__

# 2) Make Enum‐member == Proxy return True when appropriate
def __eq__(self, other):
if isinstance(other, CheckableEnumMeta._Proxy):
return self is other._member
return super(cls, self).__eq__(other)

cls.__eq__ = __eq__

return cls


if TYPE_CHECKING:
import enum
Expand All @@ -28,22 +90,30 @@ class DeviceType(AutoNumberEnum):
Snitch = () #: Compute Cluster (RISC-V)


@undefined_safe_enum
@extensible_enum
class StorageType(AutoNumberEnum):
class StorageType(Enum, metaclass=CheckableEnumMeta):
""" Available data storage types in the SDFG. """

Default = () #: Scope-default storage location
Register = () #: Local data on registers, stack, or equivalent memory
CPU_Pinned = () #: Host memory that can be DMA-accessed from accelerators
CPU_Heap = () #: Host memory allocated on heap
CPU_ThreadLocal = () #: Thread-local host memory
GPU_Global = () #: GPU global memory
GPU_Shared = () #: On-GPU shared memory
SVE_Register = () #: SVE register
Snitch_TCDM = () #: Cluster-private memory
Snitch_L2 = () #: External memory
Snitch_SSR = () #: Memory accessed by SSR streamer
Default = auto() #: Scope-default storage location
Register = auto() #: Local data on registers, stack, or equivalent memory
CPU_Pinned = auto() #: Host memory that can be DMA-accessed from accelerators
CPU_Heap = auto() #: Host memory allocated on heap
CPU_ThreadLocal = auto() #: Thread-local host memory
GPU_Global = auto() #: GPU global memory
GPU_Shared = auto() #: On-GPU shared memory
SVE_Register = auto() #: SVE register
Snitch_TCDM = auto() #: Cluster-private memory
Snitch_L2 = auto() #: External memory
Snitch_SSR = auto() #: Memory accessed by SSR streamer

@dataclass
class GPU_WaveMatrix:
op_id: tuple[int]
swizzled: bool
accvgpr: bool

@dataclass
class GPU_Vgpr:
accvgpr: bool


@undefined_safe_enum
Expand Down
9 changes: 8 additions & 1 deletion dace/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,14 @@ def __set__(self, obj, val):
val = int(val)

# Check if type matches before setting
if (self.dtype is not None and not isinstance(val, self.dtype) and not (val is None and self.allow_none)):

if hasattr(self.dtype, "_Proxy"):
accepted_types = (self.dtype, self.dtype._Proxy)
else:
accepted_types = (self.dtype,)

if (self.dtype is not None and not isinstance(val, accepted_types) and not (val is None and self.allow_none)):

if isinstance(val, str):
raise TypeError("Received str for property {} of type {}. Use "
"from_string method of the property.".format(self.attr_name, self.dtype))
Expand Down
Loading