Skip to content

Commit

Permalink
Merge pull request #233 from google-ai-edge/optionalpytorch
Browse files Browse the repository at this point in the history
Minor improvements on making torch optional.
  • Loading branch information
jinjingforever authored Nov 1, 2024
2 parents 1872e26 + 47eb847 commit 6aa67b6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 9 deletions.
7 changes: 7 additions & 0 deletions src/server/package/src/model_explorer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ def add_model_from_pytorch(
exported_program: the ExportedProgram from torch.export.export.
settings: The settings that config the visualization.
"""

if torch is None:
raise ImportError(
'`torch` not found. Please install it via `pip install torch`, '
'and restart the Model Explorer server.'
)

# Convert the given model to model explorer graphs.
print('Converting pytorch model to model explorer graphs...')
adapter = PytorchExportedProgramAdapterImpl(exported_program, settings)
Expand Down
27 changes: 18 additions & 9 deletions src/server/package/src/model_explorer/extension_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
from importlib import import_module
from typing import Any, Dict, Union

try:
import torch
except ImportError:
torch = None

from .adapter_runner import AdapterRunner
from .consts import MODULE_NAME
from .extension_class_processor import ExtensionClassProcessor
Expand All @@ -29,15 +34,19 @@


class ExtensionManager(object, metaclass=Singleton):
BUILTIN_ADAPTER_MODULES: list[str] = [
'.builtin_tflite_flatbuffer_adapter',
'.builtin_tflite_mlir_adapter',
'.builtin_tf_mlir_adapter',
'.builtin_tf_direct_adapter',
'.builtin_graphdef_adapter',
'.builtin_pytorch_exportedprogram_adapter',
'.builtin_mlir_adapter',
]
BUILTIN_ADAPTER_MODULES: list[str] = (
[
'.builtin_tflite_flatbuffer_adapter',
'.builtin_tflite_mlir_adapter',
'.builtin_tf_mlir_adapter',
'.builtin_tf_direct_adapter',
'.builtin_graphdef_adapter',
]
+ (['.builtin_pytorch_exportedprogram_adapter'] if torch else [])
+ [
'.builtin_mlir_adapter',
]
)

CACHED_REGISTERED_EXTENSIONS: Dict[str, RegisteredExtension] = {}

Expand Down

0 comments on commit 6aa67b6

Please sign in to comment.