Skip to content

Commit

Permalink
rng_bit_generator is unsupported
Browse files Browse the repository at this point in the history
  • Loading branch information
radenmuaz committed Jan 28, 2024
1 parent 53206df commit c5766c4
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 190 deletions.
6 changes: 6 additions & 0 deletions examples/simple/rand.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import slope

x = slope.tensor([1,1], dtype=slope.uint64)
y = slope.rng_bits(x);print(y)
# print(slope.rand(5))
# print(slope.rand(5))
193 changes: 10 additions & 183 deletions experimental/iree_compile_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,16 @@ def create_stablehlo_module(instance):
%4 = "stablehlo.maximum"(%2, %3) : (tensor<1x10xf32>, tensor<1x10xf32>) -> tensor<1x10xf32>
"func.return"(%4): (tensor<1x10xf32>) -> ()
}
""",
"""

# '''
# func.func @main (%x0: tensor<2xui64>) -> (tensor<2xui64>)
# {
# %y_, %y = mhlo.rng_bit_generator %x, algorithm = THREE_FRY : (tensor<2xui64>) -> (tensor<2xui64>, tensor<2x2xui64>)
# "func.return"(%y0): (tensor<2xui64>) -> ()
# }
# '''
,
target_backends=iree.compiler.DEFAULT_TESTING_BACKENDS,
)
m = iree.runtime.VmModule.from_flatbuffer(instance, binary)
Expand All @@ -110,185 +119,3 @@ def create_stablehlo_module(instance):
print("result:", result.to_host())

breakpoint()

# m = create_add_scalar_module(instance)
# context = iree.runtime.VmContext(instance, modules=[hal_module, m])
# f = m.lookup_function("add_scalar")
# finv = iree.runtime.FunctionInvoker(context, device, f, tracer=None)
# result = finv(5, 6)
# logging.info("result: %s", result)

# m = create_simple_dynamic_abs_module(instance)
# context = iree.runtime.VmContext(instance, modules=[hal_module, m])
# f = m.lookup_function("dynamic_abs")
# finv = iree.runtime.FunctionInvoker(context, device, f, tracer=None)
# arg0 = np.array([[-1.0, 2.0], [3.0, -4.0]], dtype=np.float32)
# result = finv(arg0)
# logging.info("result: %s", result)
# np.testing.assert_allclose(result, [[1.0, 2.0], [3.0, 4.0]])

# m = create_simple_static_mul_module(instance)
# context = iree.runtime.VmContext(instance, modules=[hal_module, m])
# f = m.lookup_function("simple_mul")
# finv = iree.runtime.FunctionInvoker(context, device, f, tracer=None)
# arg0 = iree.runtime.asdevicearray(device, np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
# arg1 = iree.runtime.asdevicearray(device,np.array([4.0, 5.0, 6.0, 7.0], dtype=np.float32))
# result = finv(arg0, arg1)
# logging.info("result: %s", result)
# np.testing.assert_allclose(result, [4.0, 10.0, 18.0, 28.0])

# class DeviceHalTest(unittest.TestCase):
# def setUp(self):
# super().setUp()
# self.device = iree.runtime.get_device("local-task")
# self.allocator = self.device.allocator
# # Make sure device setup maintains proper references.
# gc.collect()

# def testGcShutdownFiasco(self):
# init_ary = np.zeros([3, 4], dtype=np.int32) + 2
# ary = iree.runtime.asdevicearray(self.device, init_ary)

# # Drop all references to backing objects in reverse order to try to
# # trigger heap use-after-free on bad shutdown order.
# self.allocator = None
# gc.collect()
# self.device = None
# gc.collect()

# # Now drop the ary and make sure nothing crashes (which would indicate
# # a reference counting problem of some kind): The array should retain
# # everything that it needs to stay live.
# ary = None
# gc.collect()

# def testMetadataAttributes(self):
# init_ary = np.zeros([3, 4], dtype=np.int32) + 2
# ary = iree.runtime.asdevicearray(self.device, init_ary)
# self.assertEqual([3, 4], ary.shape)
# self.assertEqual(np.int32, ary.dtype)

# def testExplicitHostTransfer(self):
# init_ary = np.zeros([3, 4], dtype=np.int32) + 2
# ary = iree.runtime.asdevicearray(self.device, init_ary)
# self.assertEqual(repr(ary), "<IREE DeviceArray: shape=[3, 4], dtype=int32>")
# self.assertFalse(ary.is_host_accessible)

# # Explicit transfer.
# cp = ary.to_host()
# np.testing.assert_array_equal(cp, init_ary)
# self.assertTrue(ary.is_host_accessible)

# def testOverrideDtype(self):
# init_ary = np.zeros([3, 4], dtype=np.int32) + 2
# buffer_view = self.allocator.allocate_buffer_copy(
# memory_type=iree.runtime.MemoryType.DEVICE_LOCAL,
# allowed_usage=iree.runtime.BufferUsage.DEFAULT,
# device=self.device,
# buffer=init_ary,
# element_type=iree.runtime.HalElementType.SINT_32,
# )

# ary = iree.runtime.DeviceArray(
# self.device, buffer_view, override_dtype=np.float32
# )

# # Explicit transfer.
# cp = ary.to_host()
# self.assertEqual(cp.dtype, np.float32)
# np.testing.assert_array_equal(cp, init_ary.astype(np.float32))
# self.assertTrue(ary.is_host_accessible)

# def testIllegalImplicitHostTransfer(self):
# init_ary = np.zeros([3, 4], dtype=np.int32) + 2
# ary = iree.runtime.asdevicearray(self.device, init_ary)
# # Implicit transfer.
# with self.assertRaises(ValueError):
# _ = np.asarray(ary)

# def testImplicitHostArithmetic(self):
# init_ary = np.zeros([3, 4], dtype=np.int32) + 2
# ary = iree.runtime.asdevicearray(
# self.device, init_ary, implicit_host_transfer=True
# )
# sum = ary + init_ary
# np.testing.assert_array_equal(sum, init_ary + 2)
# self.assertTrue(ary.is_host_accessible)

# def testArrayFunctions(self):
# init_ary = np.zeros([3, 4], dtype=np.float32) + 2
# ary = iree.runtime.asdevicearray(
# self.device, init_ary, implicit_host_transfer=True
# )
# f = np.isfinite(ary)
# self.assertTrue(f.all())

# def testIteration(self):
# init_ary = np.array([0, 1, 2, 3, 4, 5])
# ary = iree.runtime.asdevicearray(
# self.device, init_ary, implicit_host_transfer=True
# )

# for index, value in enumerate(ary):
# self.assertEqual(index, value)

# def testSubscriptable(self):
# init_ary = np.array([0, 1, 2, 3, 4, 5])
# ary = iree.runtime.asdevicearray(
# self.device, init_ary, implicit_host_transfer=True
# )

# for index in range(0, 6):
# value = ary[index]
# self.assertEqual(index, value)

# def testReshape(self):
# init_ary = np.zeros([3, 4], dtype=np.float32) + 2
# ary = iree.runtime.asdevicearray(
# self.device, init_ary, implicit_host_transfer=True
# )
# reshaped = ary.reshape((4, 3))
# self.assertEqual((4, 3), reshaped.shape)

# np_reshaped = np.reshape(ary, (2, 2, 3))
# self.assertEqual((2, 2, 3), np_reshaped.shape)

# def testDeepcopy(self):
# init_ary = np.zeros([3, 4], dtype=np.float32) + 2
# orig_ary = iree.runtime.asdevicearray(
# self.device, init_ary, implicit_host_transfer=True
# )
# copy_ary = copy.deepcopy(orig_ary)
# self.assertIsNot(orig_ary, copy_ary)
# np.testing.assert_array_equal(orig_ary, copy_ary)

# def testAsType(self):
# init_ary = np.zeros([3, 4], dtype=np.int32) + 2
# orig_ary = iree.runtime.asdevicearray(
# self.device, init_ary, implicit_host_transfer=True
# )
# # Same dtype, no copy.
# i32_nocopy = orig_ary.astype(np.int32, copy=False)
# self.assertIs(orig_ary, i32_nocopy)

# # Same dtype, copy.
# i32_nocopy = orig_ary.astype(np.int32)
# self.assertIsNot(orig_ary, i32_nocopy)
# np.testing.assert_array_equal(orig_ary, i32_nocopy)

# # Different dtype, copy.
# f32_copy = orig_ary.astype(np.float32)
# self.assertIsNot(orig_ary, f32_copy)
# self.assertEqual(f32_copy.dtype, np.float32)
# np.testing.assert_array_equal(orig_ary.astype(np.float32), f32_copy)

# def testBool(self):
# init_ary = np.zeros([3, 4], dtype=np.bool_)
# init_ary[1] = True # Set some non-zero value.
# ary = iree.runtime.asdevicearray(self.device, init_ary)
# self.assertEqual(repr(ary), "<IREE DeviceArray: shape=[3, 4], dtype=bool>")
# np.testing.assert_array_equal(ary.to_host(), init_ary)


# if __name__ == "__main__":
# unittest.main()
11 changes: 6 additions & 5 deletions src/slope/backends/iree.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def indent(code, amount):
dtypes.bool: np.dtypes.BoolDType(),
dtypes.int32: np.dtypes.Int32DType(),
dtypes.int64: np.dtypes.Int64DType(),
dtypes.uint64: np.dtypes.UInt64DType(),
dtypes.float16: np.dtypes.Float16DType(),
},
{
Expand Down Expand Up @@ -510,16 +511,17 @@ def full_impl(self, y, *, shape, fill_value, dtype, device):
def random_uniform_impl(self, y, *, shape, dtype, device):
zero = "0." if dtypes.is_float(y.symval.dtype) else "0"
one = "1." if dtypes.is_float(y.symval.dtype) else "1"
a_type = b_type = SymbolicTensor((), dtype)
a_type = b_type = y.symval.override(shape=())
is_scalar = shape == ()
shape_val = f'dense<{repr(list(shape)) if not is_scalar else "[1]"}'
shape_type = SymbolicTensor((1,) if is_scalar else (len(shape),), Tensor.int64)
y_out_type = y.symval if not is_scalar else SymbolicTensor((1,), y.symval.dtype)
shape_type = y.symval.override(shape=(1,) if is_scalar else (len(shape),), dtype=dtypes.int64)
y_out_type = y.symval if not is_scalar else y.symval.override(shape=(1,))
return f"""%{y.name}_a = stablehlo.constant dense<{zero}> : {as_mlir_shape(a_type)}
%{y.name}_b = stablehlo.constant dense<{one}> : {as_mlir_shape(b_type)}
%{y.name}_shape = stablehlo.constant {shape_val}> : {as_mlir_shape(shape_type)}
%{y.name}{'_' if is_scalar else ''} = "stablehlo.rng"(%{y.name}_a, %{y.name}_b,%{y.name}_shape) {{
rng_distribution = #stablehlo<rng_distribution UNIFORM>}} : {as_mlir_sig((a_type, b_type, shape_type), y_out_type)}
rng_distribution = #stablehlo<rng_distribution UNIFORM>
}} : {as_mlir_sig((a_type, b_type, shape_type), y_out_type)}
{f'%{y.name} = "stablehlo.reshape"(%{y.name}_) : {as_mlir_sig((y_out_type,), y.symval)}' if is_scalar else ''}"""


Expand All @@ -539,7 +541,6 @@ def random_normal_impl(self, y, *, shape, dtype, device):
rng_distribution = #stablehlo<rng_distribution NORMAL>}} : {as_mlir_sig((a_type, b_type, shape_type), y_out_type)}
{f'%{y.name} = "stablehlo.reshape"(%{y.name}_) : {as_mlir_sig((y_out_type,), y.symval)}' if is_scalar else ''}"""


@backend.set_impl(backend.operator_set.expand)
def expand_impl(self, x, y, *, shape):
return f"""%{y.name} = "stablehlo.broadcast_in_dim"(%{x.name}) {{
Expand Down
5 changes: 3 additions & 2 deletions src/slope/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,18 @@ class dtypes:
bool: Final[DType] = DType(0, 1, "bool", "i1", bool)
int32: Final[DType] = DType(1, 4, "int32", "i32", np.int32)
int64: Final[DType] = DType(2, 8, "int64", "i64", np.int64)
uint64: Final[DType] = DType(2, 8, "uint64", "ui64", np.uint64)
float16: Final[DType] = DType(0, 2, "float16", "f16", np.float16)

all_dtypes = (bool, float16, float32, int8, int32, int64, uint8)
all_dtypes = (bool, float16, float32, int8, int32, int64, uint8, uint64)
name_dtype_map = {k.name: k for k in all_dtypes}
name_dtype_map_inv = {v: k for k, v in name_dtype_map.items()}
mlir_dtype_map = {k.mlir: k for k in all_dtypes}
mlir_dtype_map_inv = {v: k for k, v in mlir_dtype_map.items()}

@classmethod
def is_int(cls, dtype):
return dtype in (cls.uint8, cls.int8, cls.int32, cls.int64)
return dtype in (cls.uint8, cls.int8, cls.int32, cls.uint64, cls.int64)

@classmethod
def is_float(cls, dtype):
Expand Down

0 comments on commit c5766c4

Please sign in to comment.