Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add runtime_platform_compatibility flag #3090

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def compile(
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT,
enable_cross_compile_for_windows: bool = _defaults.ENABLE_CROSS_COMPILE_FOR_WINDOWS,
cache_built_engines: bool = _defaults.CACHE_BUILT_ENGINES,
reuse_cached_engines: bool = _defaults.REUSE_CACHED_ENGINES,
engine_cache_dir: str = _defaults.ENGINE_CACHE_DIR,
Expand Down Expand Up @@ -153,6 +154,7 @@ def compile(
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
enable_cross_compile_for_windows (bool): Whether to enable cross-platform compatibility so that an engine built on Linux can be saved and later executed on Windows.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
engine_cache_dir (Optional[str]): Directory to store the cached TRT engines
Expand Down Expand Up @@ -279,6 +281,7 @@ def compile(
"hardware_compatible": hardware_compatible,
"timing_cache_path": timing_cache_path,
"lazy_engine_init": lazy_engine_init,
"enable_cross_compile_for_windows": enable_cross_compile_for_windows,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
}
Expand Down Expand Up @@ -473,7 +476,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
# Replace all FX Modules with TRT Modules
for name, trt_module in trt_modules.items():
setattr(partitioned_module, name, trt_module)
if settings.lazy_engine_init:
# When cross-compiling for Windows, skip engine setup on the Linux build host
if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows:
getattr(partitioned_module, name).setup_engine()

# Reset settings object to user specification after fallback to global partitioning mode
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
tempfile.gettempdir(), "torch_tensorrt_engine_cache", "timing_cache.bin"
)
LAZY_ENGINE_INIT = False
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
CACHE_BUILT_ENGINES = False
REUSE_CACHED_ENGINES = False
ENGINE_CACHE_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_engine_cache")
Expand Down
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DLA_LOCAL_DRAM_SIZE,
DLA_SRAM_SIZE,
DRYRUN,
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENABLED_PRECISIONS,
ENGINE_CAPABILITY,
Expand Down Expand Up @@ -76,6 +77,8 @@ class CompilationSettings:
output to a file if a string path is specified
hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
enable_cross_compile_for_windows (bool): By default this is False, meaning TensorRT engines can only be executed on the same platform where they were built.
Setting it to True enables cross-platform compatibility, allowing an engine to be built on one platform and run on another.
cache_built_engines (bool): Whether to save the compiled TRT engines to storage
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
"""
Expand Down Expand Up @@ -110,6 +113,7 @@ class CompilationSettings:
hardware_compatible: bool = HARDWARE_COMPATIBLE
timing_cache_path: str = TIMING_CACHE_PATH
lazy_engine_init: bool = LAZY_ENGINE_INIT
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES

Expand Down
15 changes: 15 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import io
import logging
import os
import platform
import warnings
from datetime import datetime
from typing import (
Expand Down Expand Up @@ -123,6 +124,8 @@ def __init__(
dict()
)

self.compilation_settings = compilation_settings
self.validate_compile_settings()
# Data types for TRT Module output Tensors
self.output_dtypes = (
[dtype._from(o) for o in output_dtypes] if output_dtypes else None
Expand Down Expand Up @@ -184,6 +187,12 @@ def validate_compile_settings(self) -> None:
):
raise RuntimeError("Current platform doesn't support fast native int8!")

if self.compilation_settings.enable_cross_compile_for_windows:
if platform.system() != "Linux" or platform.architecture()[0] != "64bit":
raise RuntimeError(
f"Cross compile for windows is only supported on AMD 64bit Linux architecture, current platform: {platform.system()=}, {platform.architecture()[0]=}"
)

if (
dtype.f16 in self.compilation_settings.enabled_precisions
and not self.builder.platform_has_fast_fp16
Expand All @@ -207,6 +216,12 @@ def _populate_trt_builder_config(
trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size
)

if self.compilation_settings.enable_cross_compile_for_windows:
builder_config.runtime_platform = trt.RuntimePlatform.WINDOWS_AMD64
_LOGGER.info(
"Setting runtime_platform as trt.RuntimePlatform.WINDOWS_AMD64"
)

if version.parse(trt.__version__) >= version.parse("8.2"):
builder_config.profiling_verbosity = (
trt.ProfilingVerbosity.DETAILED
Expand Down
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
from typing import Any, List, Optional, Sequence

import tensorrt as trt
import torch
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Device import Device
Expand All @@ -18,8 +19,6 @@
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import get_model_device, get_torch_inputs

import tensorrt as trt

logger = logging.getLogger(__name__)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ def __init__(
self.weight_name_map = weight_name_map
self.target_platform = Platform.current_platform()

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
if (
self.serialized_engine is not None
and not self.settings.lazy_engine_init
and not self.settings.enable_cross_compile_for_windows
):
self.setup_engine()

def setup_engine(self) -> None:
Expand Down
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ def __init__(
self.serialized_engine = serialized_engine
self.engine = None

if serialized_engine and not self.settings.lazy_engine_init:
if (
serialized_engine
and not self.settings.lazy_engine_init
and not self.settings.enable_cross_compile_for_windows
):
self.setup_engine()

def _pack_engine_info(self) -> List[str | bytes]:
Expand Down
31 changes: 31 additions & 0 deletions tests/py/dynamo/runtime/test_000_compilation_settings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
import os
import platform
import tempfile
import unittest

import pytest
import torch
import torch_tensorrt
from torch.testing._internal.common_utils import TestCase, run_tests
Expand All @@ -6,6 +12,31 @@


class TestEnableTRTFlags(TestCase):
@unittest.skipIf(
    not (platform.system() == "Linux" and platform.architecture()[0] == "64bit"),
    "Cross compile for windows can only be enabled on linux 64 AMD platform",
)
@pytest.mark.unit
def test_enable_cross_compile_for_windows(self):
    """Smoke-test the ``enable_cross_compile_for_windows`` compile flag.

    Traces and compiles a trivial elementwise-add module with the flag set,
    then saves the resulting program to a temp-dir path. Only runs on 64-bit
    Linux, since cross-compilation for Windows is restricted to that platform.
    """

    class Add(torch.nn.Module):
        # Minimal two-input module so the graph has exactly one TRT-supported op.
        def forward(self, a, b):
            return torch.add(a, b)

    module = Add().eval().cuda()
    sample_inputs = [torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda()]
    saved_program_path = os.path.join(tempfile.gettempdir(), "trt.ep")
    settings = {
        "inputs": sample_inputs,
        "ir": "dynamo",
        "min_block_size": 1,
        "debug": True,
        "enable_cross_compile_for_windows": True,
    }
    exported = torch_tensorrt.dynamo.trace(module, **settings)
    compiled = torch_tensorrt.dynamo.compile(exported, **settings)
    torch_tensorrt.save(compiled, saved_program_path, inputs=sample_inputs)
    # Clear dynamo caches so later tests start from a clean compilation state.
    torch._dynamo.reset()

def test_toggle_build_args(self):
class AddSoftmax(torch.nn.Module):
def forward(self, x):
Expand Down
Loading