Skip to content

Commit

Permalink
add a test that updates without a return value
Browse files Browse the repository at this point in the history
doing a flow tensor update on a global but returning none
only works if you re-assign to the global
  • Loading branch information
dan-garvey committed Oct 13, 2023
1 parent 8b8aab2 commit 6afa47a
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions tests/aot/globals_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,33 @@ def testTensorUpdateGlobal(self):
state_example = torch.randn(5, 20)
update_example = torch.randn(1, 20)

class SingleState(CompiledModule):
class UpdateState(CompiledModule):
state0 = export_global(state_example, mutable=True, initialize=False)

def tensor_update_state(self, update=abstractify(update_example)):
return IREE.tensor_update(self.state0, update, 0, 0)

inst = SingleState(context=Context())
inst = UpdateState(context=Context())
module_str = str(CompiledModule.get_mlir_module(inst))
print(module_str)
self.assertIn(
"flow.tensor.update %arg0, %_state0.global[%c0, %c0] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>",
module_str,
)

def testTensorUpdateGlobalReturnNone(self):
state_example = torch.randn(5, 20, 4)
update_example = torch.randn(1, 1, 4)

class UpdateState(CompiledModule):
state0 = export_global(state_example, mutable=True, initialize=False)

def tensor_update_state(self, update=abstractify(update_example)):
thing = []
self.state0 = IREE.tensor_update(self.state0, update, 4, 0, 0)
return None

inst = UpdateState(context=Context())
module_str = str(CompiledModule.get_mlir_module(inst))
print(module_str)
self.assertIn(
Expand Down

0 comments on commit 6afa47a

Please sign in to comment.