Skip to content

Commit

Permalink
Add example test runner
Browse files Browse the repository at this point in the history
  • Loading branch information
stellaraccident committed Oct 9, 2023
1 parent 5d68a47 commit b2a291b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/shark_turbine/aot/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def export(mdl: torch.nn.Module, *example_args: torch.Tensor) -> ExportOutput:
"""One shot export of an nn.Module.
This is a very restrictive API vs the lower level `CompiledModule`
facility. It is suitable for one-short modules, with a single
facility. It is suitable for one-shot modules, with a single
entrypoint and static example arguments where no additional
configuration is needed for mutable parameters/buffers or state
management. Dynamic shape constraints are also not presently
Expand Down
28 changes: 28 additions & 0 deletions tests/examples/aot_mlp_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright 2023 Nod Labs, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import logging
from pathlib import Path
import sys
import subprocess
import unittest

REPO_DIR = Path(__file__).resolve().parent.parent.parent


def _run(local_path: str):
path = REPO_DIR / local_path
subprocess.check_call([sys.executable, str(path)])


class AOTMLPTest(unittest.TestCase):
def testMLPExportSimple(self):
_run("examples/aot_mlp/mlp_export_simple.py")


if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG)
unittest.main()

0 comments on commit b2a291b

Please sign in to comment.