Skip to content

Commit

Permalink
Fix assert on aot.export() with CompiledModule subclasses. (#615)
Browse files Browse the repository at this point in the history
I was hitting an assert trying to pass a subclass of `CompiledModule`
through the `aot.export()` API. The code looked correct, but the
metaclass code here:
https://github.com/nod-ai/SHARK-Turbine/blob/b22dc7f77cab3bc7a7e29ae73398468799edf713/core/shark_turbine/aot/compiled_module.py#L429
could be affecting the check. I looked through
https://stackoverflow.com/questions/33347131/determine-if-a-class-in-python-is-a-metaclass
for ideas on how to fix.
  • Loading branch information
ScottTodd authored Apr 12, 2024
1 parent b22dc7f commit 59bc67d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 2 deletions.
5 changes: 3 additions & 2 deletions core/shark_turbine/aot/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,10 @@ def export(
{(function_name or "main"): exported_program},
export_name=module_name or "module",
)
else:
assert issubclass(type(mdl), CompiledModule)
elif issubclass(mdl, CompiledModule):
TransformedModule = mdl
else:
raise TypeError(f"mdl argument (type: {type(mdl)}) is not a supported type")

session = Session()
# There are some bugs with respect to Session/context interop that we
Expand Down
16 changes: 16 additions & 0 deletions core/tests/aot/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,22 @@ def testTorchExportedProgram(self):
asm,
)

def testCompiledModuleExportedProgram(self):
class BasicModule(CompiledModule):
...

exported = export(BasicModule)
module_str = str(exported.mlir_module)
print(module_str)
self.assertIn("module @basic", module_str)

def testUnsupportedExportedProgram(self):
class UnsupportedExportType:
...

with self.assertRaises(TypeError):
export(UnsupportedExportType)


class SimpleParams(nn.Module):
def __init__(self):
Expand Down

0 comments on commit 59bc67d

Please sign in to comment.