diff --git a/export/orbax/export/obm_export.py b/export/orbax/export/obm_export.py index 8680e73b9..abf3d7cba 100644 --- a/export/orbax/export/obm_export.py +++ b/export/orbax/export/obm_export.py @@ -42,7 +42,7 @@ class ObmExport(export_base.ExportBase): def __init__( self, - module: jax_module.JaxModule, + module: jax_module.JaxModule | None, serving_configs: Sequence[osc.ServingConfig], ): """Initializes the ObmExport class."""