Skip to content

Commit

Permalink
Fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Sep 3, 2024
1 parent 055996b commit 978b73b
Showing 1 changed file with 38 additions and 7 deletions.
45 changes: 38 additions & 7 deletions tests/python/relax/test_transform_legalize_ops_create_datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,14 @@ def test_full_like():
@tvm.script.ir_module
class FullLike:
@R.function
def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"):
def main(x: R.Tensor((2, 3), "float32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"):
gv: R.Tensor((2, 3), "float32") = R.full_like(x, v)
return gv

@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"):
def main(x: R.Tensor((2, 3), "float32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "float32"):
gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="float32"))
return gv

Expand All @@ -191,14 +191,14 @@ def test_full_like_constant_scalar_fill_value():
@tvm.script.ir_module
class FullLike:
@R.function
def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"):
def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
gv: R.Tensor((2, 3), "float32") = R.full_like(x, R.const(-5, "float32"))
return gv

@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 3), "int32")) -> R.Tensor((2, 3), "float32"):
def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
gv = R.call_tir(Expected.full, R.tuple(), R.Tensor((2, 3), dtype="float32"))
return gv

Expand All @@ -217,7 +217,7 @@ def full(T_full: T.Buffer((T.int64(2), T.int64(3)), "float32")):
tvm.ir.assert_structural_equal(mod, Expected)


def test_full_like_different_dtype():
def test_full_like_different_explicit_dtype():
# fmt: off
@tvm.script.ir_module
class FullLike:
Expand Down Expand Up @@ -248,12 +248,43 @@ def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T
tvm.ir.assert_structural_equal(mod, Expected)


def test_full_like_different_inferred_dtype():
# fmt: off
@tvm.script.ir_module
class FullLike:
@R.function
def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")):
gv = R.full_like(x, v)
return gv

@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor((2, 3), "int32"), v: R.Tensor((), "float32")) -> R.Tensor((2, 3), "int32"):
gv = R.call_tir(Expected.full, (v,), R.Tensor((2, 3), dtype="int32"))
return gv

@T.prim_func(private=True)
def full(rxplaceholder: T.Buffer((), "float32"), T_full: T.Buffer((T.int64(2), T.int64(3)), "int32")):
T.func_attr({"tir.noalias": True})
for i0, i1 in T.grid(T.int64(2), T.int64(3)):
with T.block("T_full"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(rxplaceholder[()])
T.writes(T_full[ax0, ax1])
T_full[ax0, ax1] = T.Cast("int32", rxplaceholder[()])
# fmt: on

mod = LegalizeOps()(FullLike)
tvm.ir.assert_structural_equal(mod, Expected)


def test_full_like_symbolic():
# fmt: off
@tvm.script.ir_module
class FullLike:
@R.function
def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"):
def main(x: R.Tensor(("m", "n"), "float32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"):
m = T.int64()
n = T.int64()
gv: R.Tensor((m, n), "float32") = R.full_like(x, v)
Expand All @@ -262,7 +293,7 @@ def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tens
@tvm.script.ir_module
class Expected:
@R.function
def main(x: R.Tensor(("m", "n"), "int32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"):
def main(x: R.Tensor(("m", "n"), "float32"), v: R.Tensor((), "float32")) -> R.Tensor(("m", "n"), "float32"):
m = T.int64()
n = T.int64()
gv = R.call_tir(Expected.full, (v,), R.Tensor((m, n), dtype="float32"))
Expand Down

0 comments on commit 978b73b

Please sign in to comment.