From 9fde9e0438fb2954e7c8eb812f7a4cdb5c38726f Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 15 Aug 2024 11:03:46 -0700 Subject: [PATCH 1/5] add runtime_platform_compatibility flag --- .gitignore | 3 ++- MODULE.bazel.lock | 8 ++++---- py/torch_tensorrt/dynamo/_compiler.py | 3 +++ py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 4 ++++ .../dynamo/conversion/_TRTInterpreter.py | 18 ++++++++++++++++++ .../runtime/test_compilation_settings.py | 1 + 7 files changed, 33 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index c8dd4c5308..16e4f4f838 100644 --- a/.gitignore +++ b/.gitignore @@ -73,4 +73,5 @@ wheelhouse/ tests/py/dynamo/models/*.ts tests/py/dynamo/models/*.ep *.deb -*.tar.xz \ No newline at end of file +*.tar.xz +MODULE.bazel.lock \ No newline at end of file diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index 20d83056f6..10187ecee3 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -48,8 +48,8 @@ "https://bcr.bazel.build/modules/rules_cc/0.0.9/source.json": "1f1ba6fea244b616de4a554a0f4983c91a9301640c8fe0dd1d410254115c8430", "https://bcr.bazel.build/modules/rules_java/4.0.0/MODULE.bazel": "5a78a7ae82cd1a33cef56dc578c7d2a46ed0dca12643ee45edbb8417899e6f74", "https://bcr.bazel.build/modules/rules_java/7.1.0/MODULE.bazel": "30d9135a2b6561c761bd67bd4990da591e6bdc128790ce3e7afd6a3558b2fb64", - "https://bcr.bazel.build/modules/rules_java/7.6.1/MODULE.bazel": "2f14b7e8a1aa2f67ae92bc69d1ec0fa8d9f827c4e17ff5e5f02e91caa3b2d0fe", - "https://bcr.bazel.build/modules/rules_java/7.6.1/source.json": "8f3f3076554e1558e8e468b2232991c510ecbcbed9e6f8c06ac31c93bcf38362", + "https://bcr.bazel.build/modules/rules_java/7.6.5/MODULE.bazel": "481164be5e02e4cab6e77a36927683263be56b7e36fef918b458d7a8a1ebadb1", + "https://bcr.bazel.build/modules/rules_java/7.6.5/source.json": "a805b889531d1690e3c72a7a7e47a870d00323186a9904b36af83aa3d053ee8d", "https://bcr.bazel.build/modules/rules_jvm_external/4.4.2/MODULE.bazel": "a56b85e418c83eb1839819f0b515c431010160383306d13ec21959ac412d2fe7", "https://bcr.bazel.build/modules/rules_jvm_external/5.1/MODULE.bazel": "33f6f999e03183f7d088c9be518a63467dfd0be94a11d0055fe2d210f89aa909", "https://bcr.bazel.build/modules/rules_jvm_external/5.1/source.json": "5abb45cc9beb27b77aec6a65a11855ef2b55d95dfdc358e9f312b78ae0ba32d5", @@ -73,8 +73,8 @@ "https://bcr.bazel.build/modules/upb/0.0.0-20230516-61a97ef/source.json": "b2150404947339e8b947c6b16baa39fa75657f4ddec5e37272c7b11c7ab533bc", "https://bcr.bazel.build/modules/zlib/1.2.11/MODULE.bazel": "07b389abc85fdbca459b69e2ec656ae5622873af3f845e1c9d80fe179f3effa0", "https://bcr.bazel.build/modules/zlib/1.2.12/MODULE.bazel": "3b1a8834ada2a883674be8cbd36ede1b6ec481477ada359cd2d3ddc562340b27", - "https://bcr.bazel.build/modules/zlib/1.3/MODULE.bazel": "6a9c02f19a24dcedb05572b2381446e27c272cd383aed11d41d99da9e3167a72", - "https://bcr.bazel.build/modules/zlib/1.3/source.json": "b6b43d0737af846022636e6e255fd4a96fee0d34f08f3830e6e0bac51465c37c" + "https://bcr.bazel.build/modules/zlib/1.3.1.bcr.3/MODULE.bazel": "af322bc08976524477c79d1e45e241b6efbeb918c497e8840b8ab116802dda79", + "https://bcr.bazel.build/modules/zlib/1.3.1.bcr.3/source.json": "2be409ac3c7601245958cd4fcdff4288be79ed23bd690b4b951f500d54ee6e7d" }, "selectedYankedVersions": {}, "moduleExtensions": { diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index c97c3a6229..bd87284be5 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -83,6 +83,7 @@ def compile( hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, + enable_cross_platform_compatibility: bool = _defaults.ENABLE_CROSS_PLATFORM_COMPATIBILITY, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -148,6 +149,7 @@ def compile( hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. + enable_cross_platform_compatibility (bool): flag whether to enable cross-platform compatibility. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -257,6 +259,7 @@ def compile( "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, "lazy_engine_init": lazy_engine_init, + "enable_cross_platform_compatibility": enable_cross_platform_compatibility, } settings = CompilationSettings(**compilation_options) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 2696e26936..15dfcefe01 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -33,6 +33,7 @@ SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8} TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin") LAZY_ENGINE_INIT = False +ENABLE_CROSS_PLATFORM_COMPATIBILITY = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 4a9792d3dc..0cdc8d4c35 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -12,6 +12,7 @@ DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, DRYRUN, + ENABLE_CROSS_PLATFORM_COMPATIBILITY, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENABLED_PRECISIONS, ENGINE_CAPABILITY, @@ -74,6 +75,8 @@ class CompilationSettings: output to a file if a string path is specified hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation + enable_cross_platform_compatibility (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built. + True will enable cross-platform compatibility which allows the engine to be built on one platform and run on another platform """ enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS) @@ -106,3 +109,4 @@ class CompilationSettings: hardware_compatible: bool = HARDWARE_COMPATIBLE timing_cache_path: str = TIMING_CACHE_PATH lazy_engine_init: bool = LAZY_ENGINE_INIT + enable_cross_platform_compatibility: bool = ENABLE_CROSS_PLATFORM_COMPATIBILITY diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 9a3cace599..d641fc47b7 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -1,6 +1,7 @@ import io import logging import os +import platform import warnings from datetime import datetime from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Set, Tuple @@ -180,6 +181,23 @@ def _populate_trt_builder_config( builder_config.set_memory_pool_limit( trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size ) + if version.parse(trt.__version__) >= version.parse("10.2"): + if self.compilation_settings.enable_cross_platform_compatibility: + # currently this flag can only be enabled when building engines on Linux AMD64 platforms + # and target platform for engine execution as Windows AMD64 system. + # https://github.com/NVIDIA/TensorRT/blob/c5b9de37f7ef9034e2efc621c664145c7c12436e/include/NvInfer.h#L8257 + if ( + platform.system() == "Linux" + and platform.architecture()[0] == "64bit" + ): + _LOGGER.info( + f"Setting cross platform compatibility to {self.compilation_settings.enable_cross_platform_compatibility}" + ) + builder_config.runtime_platform = trt.RuntimePlatform.WINDOWS_AMD64 + else: + warnings.warn( + f"{platform.system()=},{platform.architecture()[0]=}, Cross platform compatibility can only be enabled when building engines on Linux AMD64 platform and running the engine on Windows AMD64 platform." + ) if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( diff --git a/tests/py/dynamo/runtime/test_compilation_settings.py b/tests/py/dynamo/runtime/test_compilation_settings.py index 47f700038a..b7d967e01a 100644 --- a/tests/py/dynamo/runtime/test_compilation_settings.py +++ b/tests/py/dynamo/runtime/test_compilation_settings.py @@ -37,6 +37,7 @@ def forward(self, x): num_avg_timing_iters=5, workspace_size=1 << 10, truncate_double=True, + enable_cross_platform_compatibility=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() From c298fc0ae75e7fa9cc869839e5287674150639bc Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 15 Aug 2024 14:13:44 -0700 Subject: [PATCH 2/5] test --- tests/py/dynamo/runtime/test_000_compilation_settings.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/py/dynamo/runtime/test_000_compilation_settings.py b/tests/py/dynamo/runtime/test_000_compilation_settings.py index b7d967e01a..47f700038a 100644 --- a/tests/py/dynamo/runtime/test_000_compilation_settings.py +++ b/tests/py/dynamo/runtime/test_000_compilation_settings.py @@ -37,7 +37,6 @@ def forward(self, x): num_avg_timing_iters=5, workspace_size=1 << 10, truncate_double=True, - enable_cross_platform_compatibility=True, ) optimized_model_results = optimized_model(*inputs).detach().cpu() From ed246c24ffbc4f40b39ba693ada57787962ff666 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Thu, 15 Aug 2024 14:17:47 -0700 Subject: [PATCH 3/5] test --- py/torch_tensorrt/_compile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 0505d90f89..910e98fa41 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -534,4 +534,4 @@ def save( exp_program = torch.export.export( module, tuple(arg_inputs), kwargs=kwarg_inputs, strict=False ) - torch.export.save(exp_program, file_path) \ No newline at end of file + torch.export.save(exp_program, file_path) From 3a31bbbfe9447a02e7ab2df76f1b6e241c8ef1ca Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Tue, 20 Aug 2024 11:07:21 -0700 Subject: [PATCH 4/5] change flag to cross compile for windows --- py/torch_tensorrt/dynamo/_compiler.py | 9 ++--- py/torch_tensorrt/dynamo/_defaults.py | 2 +- py/torch_tensorrt/dynamo/_settings.py | 6 ++-- .../dynamo/conversion/_TRTInterpreter.py | 36 +++++++++---------- .../dynamo/conversion/_conversion.py | 17 ++++----- .../runtime/_PythonTorchTensorRTModule.py | 6 +++- .../dynamo/runtime/_TorchTensorRTModule.py | 6 +++- .../runtime/test_000_compilation_settings.py | 31 ++++++++++++++++ 8 files changed, 77 insertions(+), 36 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index bd87284be5..5c3200461f 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -83,7 +83,7 @@ def compile( hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE, timing_cache_path: str = _defaults.TIMING_CACHE_PATH, lazy_engine_init: bool = _defaults.LAZY_ENGINE_INIT, - enable_cross_platform_compatibility: bool = _defaults.ENABLE_CROSS_PLATFORM_COMPATIBILITY, + enable_cross_compile_for_windows: bool = _defaults.ENABLE_CROSS_COMPILE_FOR_WINDOWS, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -149,7 +149,7 @@ def compile( hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime. - enable_cross_platform_compatibility (bool): flag whether to enable cross-platform compatibility. + enable_cross_compile_for_windows (bool): flag whether to enable cross-platform compatibility. **kwargs: Any, Returns: torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT @@ -259,7 +259,7 @@ def compile( "hardware_compatible": hardware_compatible, "timing_cache_path": timing_cache_path, "lazy_engine_init": lazy_engine_init, - "enable_cross_platform_compatibility": enable_cross_platform_compatibility, + "enable_cross_compile_for_windows": enable_cross_compile_for_windows, } settings = CompilationSettings(**compilation_options) @@ -487,7 +487,8 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: # Replace all FX Modules with TRT Modules for name, trt_module in trt_modules.items(): setattr(partitioned_module, name, trt_module) - if settings.lazy_engine_init: + # for the cross compile for windows, we don't setup engine in linux + if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows: getattr(partitioned_module, name).setup_engine() # Reset settings object to user specification after fallback to global partitioning mode diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 15dfcefe01..189b127df6 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -33,7 +33,7 @@ SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.bf16, dtype.i8, dtype.f8} TIMING_CACHE_PATH = os.path.join(tempfile.gettempdir(), "timing_cache.bin") LAZY_ENGINE_INIT = False -ENABLE_CROSS_PLATFORM_COMPATIBILITY = False +ENABLE_CROSS_COMPILE_FOR_WINDOWS = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 0cdc8d4c35..a7701c29f2 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -12,7 +12,7 @@ DLA_LOCAL_DRAM_SIZE, DLA_SRAM_SIZE, DRYRUN, - ENABLE_CROSS_PLATFORM_COMPATIBILITY, + ENABLE_CROSS_COMPILE_FOR_WINDOWS, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENABLED_PRECISIONS, ENGINE_CAPABILITY, @@ -75,7 +75,7 @@ class CompilationSettings: output to a file if a string path is specified hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer) timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation - enable_cross_platform_compatibility (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built. + enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built. True will enable cross-platform compatibility which allows the engine to be built on one platform and run on another platform """ @@ -109,4 +109,4 @@ class CompilationSettings: hardware_compatible: bool = HARDWARE_COMPATIBLE timing_cache_path: str = TIMING_CACHE_PATH lazy_engine_init: bool = LAZY_ENGINE_INIT - enable_cross_platform_compatibility: bool = ENABLE_CROSS_PLATFORM_COMPATIBILITY + enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index d641fc47b7..3d2aaec6e2 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -103,8 +103,9 @@ def __init__( self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = ( dict() ) - self.compilation_settings = compilation_settings + self.compilation_settings = compilation_settings + self.validate_compile_settings() # Data types for TRT Module output Tensors self.output_dtypes = ( [dtype._from(o) for o in output_dtypes] if output_dtypes else None @@ -163,6 +164,16 @@ def validate_compile_settings(self) -> None: ): raise RuntimeError("Current platform doesn't support fast native int8!") + if self.compilation_settings.enable_cross_compile_for_windows: + if version.parse(trt.__version__) <= version.parse("10.2"): + raise RuntimeError( + f"Cross compile for windows is not available in the current tensorrt version: {trt.__version__}, it can only be enabled after 10.2.0 post 1" + ) + if platform.system() != "Linux" or platform.architecture()[0] != "64bit": + raise RuntimeError( + f"Cross compile for windows is only supported on AMD 64bit Linux architecture, current platform: {platform.system()=}, {platform.architecture()[0]=}" + ) + if ( dtype.f16 in self.compilation_settings.enabled_precisions and not self.builder.platform_has_fast_fp16 @@ -181,23 +192,12 @@ def _populate_trt_builder_config( builder_config.set_memory_pool_limit( trt.MemoryPoolType.WORKSPACE, self.compilation_settings.workspace_size ) - if version.parse(trt.__version__) >= version.parse("10.2"): - if self.compilation_settings.enable_cross_platform_compatibility: - # currently this flag can only be enabled when building engines on Linux AMD64 platforms - # and target platform for engine execution as Windows AMD64 system. - # https://github.com/NVIDIA/TensorRT/blob/c5b9de37f7ef9034e2efc621c664145c7c12436e/include/NvInfer.h#L8257 - if ( - platform.system() == "Linux" - and platform.architecture()[0] == "64bit" - ): - _LOGGER.info( - f"Setting cross platform compatibility to {self.compilation_settings.enable_cross_platform_compatibility}" - ) - builder_config.runtime_platform = trt.RuntimePlatform.WINDOWS_AMD64 - else: - warnings.warn( - f"{platform.system()=},{platform.architecture()[0]=}, Cross platform compatibility can only be enabled when building engines on Linux AMD64 platform and running the engine on Windows AMD64 platform." - ) + + if self.compilation_settings.enable_cross_compile_for_windows: + builder_config.runtime_platform = trt.RuntimePlatform.WINDOWS_AMD64 + _LOGGER.info( + "Setting runtime_platform as trt.RuntimePlatform.WINDOWS_AMD64" + ) if version.parse(trt.__version__) >= version.parse("8.2"): builder_config.profiling_verbosity = ( diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index e03c6cf832..61e7028aa3 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -126,17 +126,18 @@ def convert_module( PythonTorchTensorRTModule or TorchTensorRTModule """ interpreter_result = interpret_module_to_result(module, inputs, settings) - # Test fast refit: - from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm - from torch_tensorrt.logging import TRT_LOGGER - - runtime = trt.Runtime(TRT_LOGGER) - refit_test_engine = runtime.deserialize_cuda_engine( - interpreter_result.serialized_engine - ) weight_name_map: Any = None + # Do the test refit with cached map if make_refitable is enabled if settings.make_refitable: + # Test fast refit: + from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm + from torch_tensorrt.logging import TRT_LOGGER + + runtime = trt.Runtime(TRT_LOGGER) + refit_test_engine = runtime.deserialize_cuda_engine( + interpreter_result.serialized_engine + ) weight_name_map = interpreter_result.weight_name_map try: _refit_single_trt_engine_with_gm( diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index acca0addf6..cbd081bf8b 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -106,7 +106,11 @@ def __init__( self.engine = None self.weight_name_map = weight_name_map - if self.serialized_engine is not None and not self.settings.lazy_engine_init: + if ( + self.serialized_engine is not None + and not self.settings.lazy_engine_init + and not self.settings.enable_cross_compile_for_windows + ): self.setup_engine() def setup_engine(self) -> None: diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index d72fa43262..a8592483f6 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -126,7 +126,11 @@ def __init__( self.serialized_engine = serialized_engine self.engine = None - if serialized_engine and not self.settings.lazy_engine_init: + if ( + serialized_engine + and not self.settings.lazy_engine_init + and not self.settings.enable_cross_compile_for_windows + ): self.setup_engine() def setup_engine(self) -> None: diff --git a/tests/py/dynamo/runtime/test_000_compilation_settings.py b/tests/py/dynamo/runtime/test_000_compilation_settings.py index 47f700038a..96ce2595ed 100644 --- a/tests/py/dynamo/runtime/test_000_compilation_settings.py +++ b/tests/py/dynamo/runtime/test_000_compilation_settings.py @@ -1,3 +1,9 @@ +import os +import platform +import tempfile +import unittest + +import pytest import torch import torch_tensorrt from torch.testing._internal.common_utils import TestCase, run_tests @@ -6,6 +12,31 @@ class TestEnableTRTFlags(TestCase): + @unittest.skipIf( + platform.system() != "Linux" or platform.architecture()[0] != "64bit", + "Cross compile for windows can only be enabled on linux 64 AMD platform", + ) + @pytest.mark.unit + def test_enable_cross_compile_for_windows(self): + class Add(torch.nn.Module): + def forward(self, a, b): + return torch.add(a, b) + + model = Add().eval().cuda() + inputs = [torch.randn(2, 3).cuda(), torch.randn(2, 3).cuda()] + trt_ep_path = os.path.join(tempfile.gettempdir(), "trt.ep") + compile_spec = { + "inputs": inputs, + "ir": "dynamo", + "min_block_size": 1, + "debug": True, + "enable_cross_compile_for_windows": True, + } + exp_program = torch_tensorrt.dynamo.trace(model, **compile_spec) + trt_gm = torch_tensorrt.dynamo.compile(exp_program, **compile_spec) + torch_tensorrt.save(trt_gm, trt_ep_path, inputs=inputs) + torch._dynamo.reset() + def test_toggle_build_args(self): class AddSoftmax(torch.nn.Module): def forward(self, x): From 697c46b20d21e29d46a6c94f84be89c254c990d4 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Wed, 18 Sep 2024 11:09:45 -0700 Subject: [PATCH 5/5] test --- .../dynamo/conversion/_TRTInterpreter.py | 4 ---- .../dynamo/conversion/_conversion.py | 24 ------------------- 2 files changed, 28 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index c085cb3c88..19fdb95f7b 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -188,10 +188,6 @@ def validate_compile_settings(self) -> None: raise RuntimeError("Current platform doesn't support fast native int8!") if self.compilation_settings.enable_cross_compile_for_windows: - if version.parse(trt.__version__) <= version.parse("10.2"): - raise RuntimeError( - f"Cross compile for windows is not available in the current tensorrt version: {trt.__version__}, it can only be enabled after 10.2.0 post 1" - ) if platform.system() != "Linux" or platform.architecture()[0] != "64bit": raise RuntimeError( f"Cross compile for windows is only supported on AMD 64bit Linux architecture, current platform: {platform.system()=}, {platform.architecture()[0]=}" diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index fe230b2685..06fade9674 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -138,30 +138,6 @@ def convert_module( Returns: PythonTorchTensorRTModule or TorchTensorRTModule """ - interpreter_result = interpret_module_to_result(module, inputs, settings) - weight_name_map: Any = None - - # Do the test refit with cached map if make_refitable is enabled - if settings.make_refitable: - # Test fast refit: - from torch_tensorrt.dynamo._refit import _refit_single_trt_engine_with_gm - from torch_tensorrt.logging import TRT_LOGGER - - runtime = trt.Runtime(TRT_LOGGER) - refit_test_engine = runtime.deserialize_cuda_engine( - interpreter_result.serialized_engine - ) - weight_name_map = interpreter_result.weight_name_map - try: - _refit_single_trt_engine_with_gm( - new_gm=module, - old_engine=refit_test_engine, - input_list=inputs, - settings=settings, - weight_name_map=interpreter_result.weight_name_map, - ) - except AssertionError: - logger.warning("Fast refit test failed. Removing the weight map caching.") interpreter_result = interpret_module_to_result( module, inputs, settings, engine_cache=engine_cache )