From 939674800144f8f7420ffff99830bbb4beba9f95 Mon Sep 17 00:00:00 2001 From: Alexander Maeder Date: Fri, 30 Jan 2026 01:01:09 -0800 Subject: [PATCH] Added initial design for storage types with attributes. Added example storage types for AMD GPU vector registers. --- dace/codegen/dispatcher.py | 7 +-- dace/dtypes.py | 98 ++++++++++++++++++++++++++++++++------ dace/properties.py | 9 +++- 3 files changed, 96 insertions(+), 18 deletions(-) diff --git a/dace/codegen/dispatcher.py b/dace/codegen/dispatcher.py index ff15f110bb..f608e055b0 100644 --- a/dace/codegen/dispatcher.py +++ b/dace/codegen/dispatcher.py @@ -315,7 +315,8 @@ def register_array_dispatcher(self, storage_type: dtypes.StorageType, func: targ self.register_array_dispatcher(stype, func) return - if not isinstance(storage_type, dtypes.StorageType): raise TypeError + if not (isinstance(storage_type, dtypes.StorageType) or isinstance(storage_type, dtypes.CheckableEnumMeta._Proxy)): + raise TypeError if not isinstance(func, target.TargetCodeGenerator): raise TypeError self._array_dispatchers[storage_type] = func @@ -345,8 +346,8 @@ def register_copy_dispatcher(self, :see: TargetCodeGenerator """ - if not isinstance(src_storage, dtypes.StorageType): raise TypeError - if not isinstance(dst_storage, dtypes.StorageType): raise TypeError + if not (isinstance(src_storage, dtypes.StorageType) or isinstance(src_storage, dtypes.CheckableEnumMeta._Proxy)): raise TypeError + if not (isinstance(dst_storage, dtypes.StorageType) or isinstance(dst_storage, dtypes.CheckableEnumMeta._Proxy)): raise TypeError if (dst_schedule is not None and not isinstance(dst_schedule, dtypes.ScheduleType)): raise TypeError if not isinstance(func, target.TargetCodeGenerator): raise TypeError diff --git a/dace/dtypes.py b/dace/dtypes.py index eb49383ad3..563ba95ca2 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -12,6 +12,68 @@ from typing import Any, Dict, TYPE_CHECKING from dace.config import Config from dace.registry import extensible_enum, undefined_safe_enum +from enum import Enum, EnumMeta, auto +from dataclasses import dataclass + +class CheckableEnumMeta(EnumMeta): + class _Proxy: + __slots__ = ("_member", "name", "_obj", "_args", "_kwargs") + + def __init__(self, member, obj, args, kwargs): + self._member = member + self.name = obj.__class__.__name__ + self._obj = obj + self._args = args + self._kwargs = kwargs + + def __getattr__(self, name): + # Delegate attribute lookups to the real object + return getattr(self._obj, name) + + def __deepcopy__(self, memo): + # Avoid infinite recursion by copying only what is necessary + from copy import deepcopy + # Optionally: if _obj is immutable or doesn't need deep copy, just reference it + obj_copy = deepcopy(self._obj, memo) + args_copy = deepcopy(self._args, memo) + kwargs_copy = deepcopy(self._kwargs, memo) + return type(self)(self._member, obj_copy, args_copy, kwargs_copy) + + def __eq__(self, other): + # proxy == enum‐member? + if other is self._member: + return True + # proxy == proxy? + if isinstance(other, CheckableEnumMeta._Proxy): + return self._member is other._member and self._args == other._args and self._kwargs == other._kwargs + return False + + def __hash__(self): + return hash((id(self._member), self._args, frozenset(self._kwargs.items()))) + + def __repr__(self): + return f"{self._member.name}{self._args}{self._kwargs}" + + def __new__(mcs, clsname, bases, classdict): + cls = super().__new__(mcs, clsname, bases, classdict) + + # 1) Make Enum members callable, returning our Proxy + def __call__(self, *args, **kwargs): + real_obj = self.value(*args, **kwargs) + return CheckableEnumMeta._Proxy(self, real_obj, args, kwargs) + + cls.__call__ = __call__ + + # 2) Make Enum‐member == Proxy return True when appropriate + def __eq__(self, other): + if isinstance(other, CheckableEnumMeta._Proxy): + return self is other._member + return super(cls, self).__eq__(other) + + cls.__eq__ = __eq__ + + return cls + if TYPE_CHECKING: import enum @@ -28,22 +90,30 @@ class DeviceType(AutoNumberEnum): Snitch = () #: Compute Cluster (RISC-V) -@undefined_safe_enum -@extensible_enum -class StorageType(AutoNumberEnum): +class StorageType(Enum, metaclass=CheckableEnumMeta): """ Available data storage types in the SDFG. """ - Default = () #: Scope-default storage location - Register = () #: Local data on registers, stack, or equivalent memory - CPU_Pinned = () #: Host memory that can be DMA-accessed from accelerators - CPU_Heap = () #: Host memory allocated on heap - CPU_ThreadLocal = () #: Thread-local host memory - GPU_Global = () #: GPU global memory - GPU_Shared = () #: On-GPU shared memory - SVE_Register = () #: SVE register - Snitch_TCDM = () #: Cluster-private memory - Snitch_L2 = () #: External memory - Snitch_SSR = () #: Memory accessed by SSR streamer + Default = auto() #: Scope-default storage location + Register = auto() #: Local data on registers, stack, or equivalent memory + CPU_Pinned = auto() #: Host memory that can be DMA-accessed from accelerators + CPU_Heap = auto() #: Host memory allocated on heap + CPU_ThreadLocal = auto() #: Thread-local host memory + GPU_Global = auto() #: GPU global memory + GPU_Shared = auto() #: On-GPU shared memory + SVE_Register = auto() #: SVE register + Snitch_TCDM = auto() #: Cluster-private memory + Snitch_L2 = auto() #: External memory + Snitch_SSR = auto() #: Memory accessed by SSR streamer + + @dataclass + class GPU_WaveMatrix: + op_id: tuple[int] + swizzled: bool + accvgpr: bool + + @dataclass + class GPU_Vgpr: + accvgpr: bool @undefined_safe_enum diff --git a/dace/properties.py b/dace/properties.py index 5137abc621..007b94ff9d 100644 --- a/dace/properties.py +++ b/dace/properties.py @@ -171,7 +171,14 @@ def __set__(self, obj, val): val = int(val) # Check if type matches before setting - if (self.dtype is not None and not isinstance(val, self.dtype) and not (val is None and self.allow_none)): + + if hasattr(self.dtype, "_Proxy"): + accepted_types = (self.dtype, self.dtype._Proxy) + else: + accepted_types = (self.dtype,) + + if (self.dtype is not None and not isinstance(val, accepted_types) and not (val is None and self.allow_none)): + if isinstance(val, str): raise TypeError("Received str for property {} of type {}. Use " "from_string method of the property.".format(self.attr_name, self.dtype))