From b2a291bf733fe28202df67fb7ae42787377149ce Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 9 Oct 2023 12:07:59 -0700 Subject: [PATCH] Add example test runner --- python/shark_turbine/aot/exporter.py | 2 +- tests/examples/aot_mlp_test.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) create mode 100644 tests/examples/aot_mlp_test.py diff --git a/python/shark_turbine/aot/exporter.py b/python/shark_turbine/aot/exporter.py index 8382979a2..f93a3e153 100644 --- a/python/shark_turbine/aot/exporter.py +++ b/python/shark_turbine/aot/exporter.py @@ -155,7 +155,7 @@ def export(mdl: torch.nn.Module, *example_args: torch.Tensor) -> ExportOutput: """One shot export of an nn.Module. This is a very restrictive API vs the lower level `CompiledModule` - facility. It is suitable for one-short modules, with a single + facility. It is suitable for one-shot modules, with a single entrypoint and static example arguments where no additional configuration is needed for mutable parameters/buffers or state management. Dynamic shape constraints are also not presently diff --git a/tests/examples/aot_mlp_test.py b/tests/examples/aot_mlp_test.py new file mode 100644 index 000000000..d55e0b9c6 --- /dev/null +++ b/tests/examples/aot_mlp_test.py @@ -0,0 +1,28 @@ +# Copyright 2023 Nod Labs, Inc +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import logging +from pathlib import Path +import sys +import subprocess +import unittest + +REPO_DIR = Path(__file__).resolve().parent.parent.parent + + +def _run(local_path: str): + path = REPO_DIR / local_path + subprocess.check_call([sys.executable, str(path)]) + + +class AOTMLPTest(unittest.TestCase): + def testMLPExportSimple(self): + _run("examples/aot_mlp/mlp_export_simple.py") + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main()