
Commit

add a test that updates without a return value (#96)
Doing a flow.tensor.update on a global but returning None seems to cause the compiler to optimize out the update.
dan-garvey authored Oct 13, 2023
1 parent f3fc04b commit 9ae8eec
Showing 1 changed file with 65 additions and 16 deletions.
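
For quick reference, the pattern the new test guards against can be sketched outside the diff as follows. This is a minimal, hedged sketch: the import paths (shark_turbine.aot, iree.compiler.ir) are assumptions and are not part of this commit; every other name (CompiledModule, export_global, abstractify, IREE.tensor_update, Context, CompiledModule.get_mlir_module) appears in the diff below. The method writes the updated tensor back into the mutable global and returns nothing, and the emitted MLIR is then checked for the flow.tensor.update op to confirm the store was not optimized away.

# Minimal sketch only -- import paths are assumed, not taken from this commit.
import torch
from iree.compiler.ir import Context                  # assumed import location
from shark_turbine.aot import (                       # assumed package name
    IREE,
    CompiledModule,
    abstractify,
    export_global,
)

state_example = torch.randn(5, 20, 4)
update_example = torch.randn(1, 1, 4)

class UpdateState(CompiledModule):
    # Mutable, uninitialized global backing the persistent state tensor.
    state0 = export_global(state_example, mutable=True, initialize=False)

    def tensor_update_state(self, update=abstractify(update_example)):
        # Write the update back into the global and return nothing; the
        # compiled IR should still contain the corresponding flow.tensor.update.
        self.state0 = IREE.tensor_update(self.state0, update, 4, 0, 0)

inst = UpdateState(context=Context())
module_str = str(CompiledModule.get_mlir_module(inst))
assert "flow.tensor.update" in module_str
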
81 changes: 65 additions & 16 deletions tests/aot/globals_test.py
@@ -40,8 +40,12 @@ def run(self, x=AbstractTensor(128, 20)):
inst = GlobalModule(context=Context())
module_str = str(CompiledModule.get_mlir_module(inst))
print(module_str)
self.assertIn("util.global private @_params.classifier.weight", module_str)
self.assertIn("util.global private @_params.classifier.bias", module_str)
self.assertIn(
"util.global private @_params.classifier.weight", module_str
)
self.assertIn(
"util.global private @_params.classifier.bias", module_str
)

def testGlobalLoadFromPyTree(self):
m = SimpleParams()
@@ -64,7 +68,8 @@ def read_params(self):
module_str,
)
self.assertIn(
"return %_params.classifier.weight, %_params.classifier.bias", module_str
"return %_params.classifier.weight, %_params.classifier.bias",
module_str,
)

def testGlobalLoadFromPyLeaf(self):
@@ -99,22 +104,30 @@ def update_params(me, updates=abstractify(params)):
inst = GlobalModule(context=Context())
module_str = str(CompiledModule.get_mlir_module(inst))
print(module_str)
self.assertIn("util.global.store %arg0, @_params.classifier.weight", module_str)
self.assertIn("util.global.store %arg1, @_params.classifier.bias", module_str)
self.assertIn(
"util.global.store %arg0, @_params.classifier.weight", module_str
)
self.assertIn(
"util.global.store %arg1, @_params.classifier.bias", module_str
)

def testGlobalStoreFromLeaf(self):
m = SimpleParams()

class GlobalModule(CompiledModule):
params = export_parameters(m, initialize=False, mutable=True)

def update_bias(self, new_bias=abstractify(params["classifier.bias"])):
def update_bias(
self, new_bias=abstractify(params["classifier.bias"])
):
self.params["classifier.bias"] = new_bias

inst = GlobalModule(context=Context())
module_str = str(CompiledModule.get_mlir_module(inst))
print(module_str)
self.assertIn("util.global.store %arg0, @_params.classifier.bias", module_str)
self.assertIn(
"util.global.store %arg0, @_params.classifier.bias", module_str
)

def testExportSingleGlobalTensor(self):
state_example = torch.randn(3, 11)
@@ -129,7 +142,9 @@ def read_state(self):
module_str = str(CompiledModule.get_mlir_module(inst))
print(module_str)
self.assertIn("util.global private @_state0.global", module_str)
self.assertIn("%_state0.global = util.global.load @_state0.global", module_str)
self.assertIn(
"%_state0.global = util.global.load @_state0.global", module_str
)
self.assertIn("return %_state0.global", module_str)

def testExportTreeGlobalTensors(self):
@@ -155,10 +170,18 @@ def read_state(self):
self.assertIn("util.global private @_state0.seq.1", module_str)
self.assertIn("util.global private @_state0.seq.2", module_str)
self.assertIn("util.global private @_state0.data", module_str)
self.assertIn("%_state0.data = util.global.load @_state0.data", module_str)
self.assertIn("%_state0.seq.0 = util.global.load @_state0.seq.0", module_str)
self.assertIn("%_state0.seq.1 = util.global.load @_state0.seq.1", module_str)
self.assertIn("%_state0.seq.2 = util.global.load @_state0.seq.2", module_str)
self.assertIn(
"%_state0.data = util.global.load @_state0.data", module_str
)
self.assertIn(
"%_state0.seq.0 = util.global.load @_state0.seq.0", module_str
)
self.assertIn(
"%_state0.seq.1 = util.global.load @_state0.seq.1", module_str
)
self.assertIn(
"%_state0.seq.2 = util.global.load @_state0.seq.2", module_str
)
self.assertIn(
"return %_state0.data, %_state0.seq.0, %_state0.seq.1, %_state0.seq.2",
module_str,
@@ -175,7 +198,9 @@ def testUpdateGlobalStateTree(self):
}

class SingleState(CompiledModule):
state0 = export_global_tree(state_example, mutable=True, initialize=False)
state0 = export_global_tree(
state_example, mutable=True, initialize=False
)

def read_state(self, updates=abstractify(state_example)):
self.state0 = updates
@@ -196,20 +221,44 @@ def testTensorUpdateGlobal(self):
state_example = torch.randn(5, 20)
update_example = torch.randn(1, 20)

class SingleState(CompiledModule):
state0 = export_global(state_example, mutable=True, initialize=False)
class UpdateState(CompiledModule):
state0 = export_global(
state_example, mutable=True, initialize=False
)

def tensor_update_state(self, update=abstractify(update_example)):
return IREE.tensor_update(self.state0, update, 0, 0)

inst = SingleState(context=Context())
inst = UpdateState(context=Context())
module_str = str(CompiledModule.get_mlir_module(inst))
print(module_str)
self.assertIn(
"flow.tensor.update %arg0, %_state0.global[%c0, %c0] : tensor<1x20xf32> -> %_state0.global as tensor<5x20xf32>",
module_str,
)

def testTensorUpdateGlobalReturnNone(self):
state_example = torch.randn(5, 20, 4)
update_example = torch.randn(1, 1, 4)

class UpdateState(CompiledModule):
state0 = export_global(
state_example, mutable=True, initialize=False
)

def tensor_update_state(self, update=abstractify(update_example)):
self.state0 = IREE.tensor_update(self.state0, update, 4, 0, 0)
return None

inst = UpdateState(context=Context())
module_str = str(CompiledModule.get_mlir_module(inst))
print(module_str)
self.assertIn(
"flow.tensor.update %arg0, %_state0.global[%c4, %c0, %c0] : tensor<1x1x4xf32> -> %_state0.global as tensor<5x20x4xf32>",
module_str,
)


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
