diff --git a/tests/aot/globals_test.py b/tests/aot/globals_test.py index 18f5c1193..07618952c 100644 --- a/tests/aot/globals_test.py +++ b/tests/aot/globals_test.py @@ -40,8 +40,12 @@ def run(self, x=AbstractTensor(128, 20)): inst = GlobalModule(context=Context()) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) - self.assertIn("util.global private @_params.classifier.weight", module_str) - self.assertIn("util.global private @_params.classifier.bias", module_str) + self.assertIn( + "util.global private @_params.classifier.weight", module_str + ) + self.assertIn( + "util.global private @_params.classifier.bias", module_str + ) def testGlobalLoadFromPyTree(self): m = SimpleParams() @@ -64,7 +68,8 @@ def read_params(self): module_str, ) self.assertIn( - "return %_params.classifier.weight, %_params.classifier.bias", module_str + "return %_params.classifier.weight, %_params.classifier.bias", + module_str, ) def testGlobalLoadFromPyLeaf(self): @@ -99,8 +104,12 @@ def update_params(me, updates=abstractify(params)): inst = GlobalModule(context=Context()) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) - self.assertIn("util.global.store %arg0, @_params.classifier.weight", module_str) - self.assertIn("util.global.store %arg1, @_params.classifier.bias", module_str) + self.assertIn( + "util.global.store %arg0, @_params.classifier.weight", module_str + ) + self.assertIn( + "util.global.store %arg1, @_params.classifier.bias", module_str + ) def testGlobalStoreFromLeaf(self): m = SimpleParams() @@ -108,13 +117,17 @@ def testGlobalStoreFromLeaf(self): class GlobalModule(CompiledModule): params = export_parameters(m, initialize=False, mutable=True) - def update_bias(self, new_bias=abstractify(params["classifier.bias"])): + def update_bias( + self, new_bias=abstractify(params["classifier.bias"]) + ): self.params["classifier.bias"] = new_bias inst = GlobalModule(context=Context()) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) - self.assertIn("util.global.store %arg0, @_params.classifier.bias", module_str) + self.assertIn( + "util.global.store %arg0, @_params.classifier.bias", module_str + ) def testExportSingleGlobalTensor(self): state_example = torch.randn(3, 11) @@ -129,7 +142,9 @@ def read_state(self): module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) self.assertIn("util.global private @_state0.global", module_str) - self.assertIn("%_state0.global = util.global.load @_state0.global", module_str) + self.assertIn( + "%_state0.global = util.global.load @_state0.global", module_str + ) self.assertIn("return %_state0.global", module_str) def testExportTreeGlobalTensors(self): @@ -155,10 +170,18 @@ def read_state(self): self.assertIn("util.global private @_state0.seq.1", module_str) self.assertIn("util.global private @_state0.seq.2", module_str) self.assertIn("util.global private @_state0.data", module_str) - self.assertIn("%_state0.data = util.global.load @_state0.data", module_str) - self.assertIn("%_state0.seq.0 = util.global.load @_state0.seq.0", module_str) - self.assertIn("%_state0.seq.1 = util.global.load @_state0.seq.1", module_str) - self.assertIn("%_state0.seq.2 = util.global.load @_state0.seq.2", module_str) + self.assertIn( + "%_state0.data = util.global.load @_state0.data", module_str + ) + self.assertIn( + "%_state0.seq.0 = util.global.load @_state0.seq.0", module_str + ) + self.assertIn( + "%_state0.seq.1 = util.global.load @_state0.seq.1", module_str + ) + self.assertIn( + "%_state0.seq.2 = util.global.load @_state0.seq.2", module_str + ) self.assertIn( "return %_state0.data, %_state0.seq.0, %_state0.seq.1, %_state0.seq.2", module_str, @@ -175,7 +198,9 @@ def testUpdateGlobalStateTree(self): } class SingleState(CompiledModule): - state0 = export_global_tree(state_example, mutable=True, initialize=False) + state0 = export_global_tree( + state_example, mutable=True, initialize=False + ) def read_state(self, updates=abstractify(state_example)): self.state0 = updates @@ -196,13 +221,15 @@ def testTensorUpdateGlobal(self): state_example = torch.randn(5, 20) update_example = torch.randn(1, 20) - class SingleState(CompiledModule): - state0 = export_global(state_example, mutable=True, initialize=False) + class UpdateState(CompiledModule): + state0 = export_global( + state_example, mutable=True, initialize=False + ) def tensor_update_state(self, update=abstractify(update_example)): return IREE.tensor_update(self.state0, update, 0, 0) - inst = SingleState(context=Context()) + inst = UpdateState(context=Context()) module_str = str(CompiledModule.get_mlir_module(inst)) print(module_str) self.assertIn( @@ -210,6 +237,28 @@ def tensor_update_state(self, update=abstractify(update_example)): module_str, ) + def testTensorUpdateGlobalReturnNone(self): + state_example = torch.randn(5, 20, 4) + update_example = torch.randn(1, 1, 4) + + class UpdateState(CompiledModule): + state0 = export_global( + state_example, mutable=True, initialize=False + ) + + def tensor_update_state(self, update=abstractify(update_example)): + thing = [] + self.state0 = IREE.tensor_update(self.state0, update, 4, 0, 0) + return None + + inst = UpdateState(context=Context()) + module_str = str(CompiledModule.get_mlir_module(inst)) + print(module_str) + self.assertIn( + "flow.tensor.update %arg0, %_state0.global[%c4, %c0, %c0] : tensor<1x1x4xf32> -> %_state0.global as tensor<5x20x4xf32>", + module_str, + ) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG)