Releases: pytorch/TensorRT
Torch-TensorRT v2.5.0
PyTorch 2.5, CUDA 12.4, TensorRT 10.3, Python 3.12
Torch-TensorRT 2.5.0 targets PyTorch 2.5, TensorRT 10.3 and CUDA 12.4.
(builds for CUDA 11.8/12.1 are available via the PyTorch package index - https://download.pytorch.org/whl/cu118 https://download.pytorch.org/whl/cu121)
Deprecation notice
The TorchScript frontend will be deprecated in v2.6. Specifically, the following usage will no longer be supported and will issue a deprecation warning at runtime if used:
torch_tensorrt.compile(model, ir="torchscript")
Moving forward, we encourage users to transition to one of the supported options:
torch_tensorrt.compile(model)
torch_tensorrt.compile(model, ir="dynamo")
torch.compile(model, backend="tensorrt")
TorchScript will continue to be supported as a deployment format via post-compilation tracing:
dynamo_model = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=[...])
ts_model = torch.jit.trace(dynamo_model, inputs=[...])
ts_model(...)
Please refer to the README for more information regarding our deprecation policy.
Refit (Beta)
v2.5.0 introduces direct model refitting from PyTorch for your compiled Torch-TensorRT programs. Sometimes weights need to change over the course of inference, and in the past a full recompilation was necessary to swap out the weights of the model, either automatically through torch.compile or manually through torch_tensorrt.compile. Now, using the refit_module_weights API, compiled modules can be refitted by providing a new PyTorch module (with identical structure) containing the new weights. Compiled modules must be compiled with make_refittable to use this feature.
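For context, a minimal sketch (assuming a torchvision ResNet-18; exact import paths and kwarg spellings may differ slightly by version) of compiling and saving the original program with refitting enabled, so the snippet below has something to refit:
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo import refit_module_weights
# Compile the original model with refitting enabled and serialize it to disk
model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.randn((1, 3, 224, 224)).to("cuda")]
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_trt.dynamo.compile(exp_program, tuple(inputs), make_refittable=True)
torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs)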
# Create and export the updated model
model2 = models.resnet18(pretrained=True).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))
compiled_trt_ep = torch_trt.load("./compiled.ep")
# This returns a new module with updated weights
new_trt_gm = refit_module_weights(
    compiled_module=compiled_trt_ep,
    new_weight_module=exp_program2,
)
Some ops are not compatible with refit, such as ops that utilize the ILoop layer. When make_refittable is enabled, these ops will be forced to run in PyTorch. Note also that refit-enabled engines may be slightly less performant than non-refittable engines, since TensorRT cannot tune for the specific weights it will see at execution time.
Refit Caching (Experimental)
Refitting on its own can speed up model swap times by 0.5-2x. However, the speed of refit can be further improved by utilizing refit caching. Refit caching stores, at compile time, hints for a direct mapping from PyTorch module members to TRT layer names in the metadata of TorchTensorRTModule. This caching can speed up refit by orders of magnitude; however, it currently has limitations when dealing with layers that have compile-time optimization. This feature is still experimental, as there may be some ops that are not amenable to refit caching. We still enable the cache by default when refitting in order to collect feedback on the edge cases, and we provide an output validator which can be used to ensure that the refit occurred properly. When verify_outputs is True, if the refit failed, the refitter will discard the cache and refit from scratch.
new_trt_gm = refit_module_weights(
    compiled_module=compiled_trt_ep,
    new_weight_module=exp_program2,
    arg_inputs=inputs,
    verify_outputs=True,
)
MutableTorchTensorRTModule (Experimental)
torch.compile is incredibly useful when trying to optimize models that may change over time, since it can automatically recompile the module when something changes. However, its major limitation is that it cannot be serialized. For users who are looking for similar flexibility but with the added ability to serialize and move their work, we have introduced the MutableTorchTensorRTModule. This module wraps a PyTorch module and exposes its members transparently; however, it injects listeners on setattr and overrides the forward function to use TensorRT-accelerated subgraphs. This means you can make changes to your module, such as applying adapters, and the MutableTorchTensorRTModule will detect the change and mark the function for refit or recompilation based on the change. Similar to torch.compile, this is done in a JIT manner, so the first inference after a change will perform the refit or recompile operation.
import torch
import torch_tensorrt as torch_trt
from diffusers import DiffusionPipeline

with torch.no_grad():
    settings = {
        "use_python_runtime": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
        "make_refittable": True,
    }

    model_id = "runwayml/stable-diffusion-v1-5"
    device = "cuda:0"

    prompt = "house in forest, shuimobysim, wuchangshuo, best quality"
    negative = "(worst quality:2), (low quality:2), (normal quality:2), lowres, normal quality, out of focus, cloudy, (watermark:2),"

    pipe = DiffusionPipeline.from_pretrained(
        model_id, revision="fp16", torch_dtype=torch.float16
    )
    pipe.to(device)

    # The only extra line you need
    pipe.unet = torch_trt.MutableTorchTensorRTModule(pipe.unet, **settings)

    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./without_LoRA_mutable.jpg")

    # Standard Huggingface LoRA loading procedure
    pipe.load_lora_weights(
        "stablediffusionapi/load_lora_embeddings",
        weight_name="moxin.safetensors",
        adapter_name="lora1",
    )
    pipe.set_adapters(["lora1"], adapter_weights=[1])
    pipe.fuse_lora()
    pipe.unload_lora_weights()

    # Refit triggered on the next forward pass
    image = pipe(prompt, negative_prompt=negative, num_inference_steps=30).images[0]
    image.save("./with_LoRA_mutable.jpg")
Engine Caching
In some scenarios, users may compile a module multiple times, and each time a TensorRT engine has to be built in the backend, which takes a long time. Engine caching boosts performance by reusing previously compiled TensorRT engines rather than rebuilding them every time, thereby avoiding recompilation time. When a cached engine is loaded, it is refitted with the new module weights.
To make this efficient, as long as two graph modules have the same structure, we consider them the same even if their weights differ, i.e., they are isomorphic graph modules. Isomorphic graph modules with the same compilation settings will share cached engines.
We implemented DiskEngineCache so that users can directly use the APIs to control how and where cached engines are saved to and loaded from on the local disk. For example,
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refittable=True,
    cache_built_engines=True,
    reuse_cached_engines=True,
    engine_cache_dir="/tmp/torch_trt_engine_cache",
    engine_cache_size=1 << 30,  # 1GB
)
In addition, considering that some users want to save to or load engines from other servers, clusters, or the cloud, we also provide a base class, BaseEngineCache, so that users can easily implement their own logic to save and load engines. For example,
from typing import Optional

class MyEngineCache(BaseEngineCache):
    def __init__(
        self,
        addr: str,
    ) -> None:
        self.addr = addr

    def save(
        self,
        hash: str,
        blob: bytes,
        prefix: str = "blob",
    ):
        # User's customized function to save engines (write_to is a user-defined placeholder)
        write_to(self.addr, name=f"{prefix}_{hash}.bin", content=blob)

    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
        # User's customized function to load engines (read_from is a user-defined placeholder)
        return read_from(self.addr, name=f"{prefix}_{hash}.bin")
trt_gm = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    make_refittable=True,
    cache_built_engines=True,
    reuse_cached_engines=True,
    custom_engine_cache=MyEngineCache("xxxxx"),
)
CUDA Graphs
In v2.5.0, CUDA Graph support for in-engine kernel launch optimization has been added through a new runtime mode. This mode can be activated from Python using:
import torch_tensorrt
my_torchtrt_model = torch_tensorrt.compile(...)
with torch_tensorrt.runtime.enable_cudagraphs():
    my_torchtrt_model(inputs)
This mode works by creating CUDA Graphs around individual TensorRT engines, which improves their launch efficiency. The graph is created through a capture phase tied to the input shape of the engine. When the input shape changes, the graph is invalidated and automatically recaptured.
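For instance, a minimal sketch (reusing my_torchtrt_model and inputs from the snippet above) of how capture and replay behave:
with torch_tensorrt.runtime.enable_cudagraphs():
    my_torchtrt_model(inputs)  # first call: a CUDA Graph is captured for this input shape
    my_torchtrt_model(inputs)  # subsequent calls with the same shape replay the captured graph
    # if the input shape changes (for dynamic-shape engines), the capture is invalidated and redone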
Model Optimizer based int8 Quantization (PTQ) support for Linux
This version introduces official support for int8 quantization via modelopt (https://github.com/NVIDIA/TensorRT-Model-Optimizer) 17.0 for Linux.
Full examples can be found at https://github.com/pytorch/TensorRT/blob/main/examples/dynamo/vgg16_ptq.py
Running the VGG16 example for int8 PTQ:
Step 1: generate a checkpoint file for VGG16:
cd examples/int8/training/vgg16
python main.py --lr 0.01 --batch-size 128 --drop-ratio 0.15 \
--ckpt-dir $(pwd)/vgg16_ckpts --epochs 20 --seed 545
This should produce a checkpoint file at examples/int8/training/vgg16/vgg16_ckpts/ckpt_epoch20.pth
Step 2: run int8 PTQ for VGG16:
python examples/dynamo/vgg16_fp8_ptq.py --batch-size 128 \
--ckpt=examples/int8/training/vgg16/vgg16_ckpts/ckpt_epoch20.pth \
--quantize-type=int8
LLM examples
We now offer dynamic shape support for all converters (covering core ATen operations). Dynamic shapes are widely utilized in leading LLM models, where input sequence lengths may vary. With this release, we showcase full graph compilation for Ll...
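As a rough sketch (assuming a decoder-only model `model` that takes token IDs; the names here are illustrative, not from the release), a dynamic sequence length can be expressed via a torch_tensorrt.Input range:
import torch
import torch_tensorrt

# Sequence-length dimension is dynamic between 1 and 1024 tokens
seq_input = torch_tensorrt.Input(
    min_shape=(1, 1),
    opt_shape=(1, 512),
    max_shape=(1, 1024),
    dtype=torch.int64,
)
trt_model = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=[seq_input])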
Torch-TensorRT v2.4.0
C++ Runtime Support on Windows, Enhanced Dynamic Shape Support in Converters, PyTorch 2.4, CUDA 12.4, TensorRT 10.1, Python 3.12
Torch-TensorRT 2.4.0 targets PyTorch 2.4, CUDA 12.4 (builds for CUDA 11.8/12.1 are available via the PyTorch package index - https://download.pytorch.org/whl/cu118 https://download.pytorch.org/whl/cu121) and TensorRT 10.1.
This version introduces official support for the C++ runtime on the Windows platform, though it is limited to the Dynamo frontend, supporting both AOT and JIT workflows. Users can now utilize both the Python and C++ runtimes on Windows. Additionally, this release expands support to include all ATen Core Operators except torch.nonzero, and significantly increases dynamic shape support across more converters. Python 3.12 is supported for the first time in this release.
Full Windows Support
In this release we introduce both C++ and Python runtime support on Windows. Users can now directly optimize PyTorch models with TensorRT on Windows with no code changes. The C++ runtime is the default; users can enable the Python runtime by specifying use_python_runtime=True.
import torch
import torch_tensorrt
import torchvision.models as models
model = models.resnet18(pretrained=True).eval().to("cuda")
input = torch.randn((1, 3, 224, 224)).to("cuda")
trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=[input])
trt_mod(input)
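A variant of the same call (reusing the model and input above) that opts into the Python runtime instead of the default C++ runtime:
# Same compilation, but selecting the Python runtime
trt_mod_py = torch_tensorrt.compile(model, ir="dynamo", inputs=[input], use_python_runtime=True)
trt_mod_py(input)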
Enhanced Op support in Converters
Converter support now covers nearly 100% of core ATen. At this point, fallback to PyTorch execution is due either to specific limitations of individual converters or to some combination of user compiler settings (e.g. torch_executed_ops, dynamic shape). This release also expands the number of operators that support dynamic shape. The dryrun feature will provide specific information on the support for your model and settings.
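For example, a minimal sketch (assuming the dynamo frontend and a torchvision ResNet for illustration) of requesting a dryrun report instead of building engines:
import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
input = torch.randn((1, 3, 224, 224)).to("cuda")
# Reports which ops would run in TensorRT vs. PyTorch without building engines
torch_tensorrt.compile(model, ir="dynamo", inputs=[input], dryrun=True)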
What's Changed
- fix: FakeTensors appearing in get_attr calls by @gs-olive in #2669
- feat: support adaptive_avg_pool1d dynamo converter by @zewenli98 in #2614
- fix: Add cmake missing source file ref for core_lowering.passes by @Arktische in #2672
- ci: Torch nightly version upgrade to 2.4.0 by @gs-olive in #2704
- Add support for aten.pixel_unshuffle dynamo converter by @HolyWu in #2696
- feat: support aten.atan2 converter by @chohk88 in #2689
- feat: support aten.index_select converter by @chohk88 in #2710
- feat: support aten.isnan converter by @chohk88 in #2711
- feat: support adaptive avg pool 2d and 3d dynamo converters by @zewenli98 in #2632
- feat: support aten.expm1 converter by @chohk88 in #2714
- fix: Add dependencies to Docker container for apt versioning TRT by @gs-olive in #2746
- fix: Missing parameters in compiler settings by @gs-olive in #2749
- fix: param bug in test_binary_ops_aten by @zewenli98 in #2733
- aten::empty_like by @apbose in #2654
- empty_permute decomposition by @apbose in #2698
- Removing grid lowering by @apbose in #2686
- Selectively enable different frontends by @narendasan in #2693
- chore(deps): bump transformers from 4.33.2 to 4.36.0 in /tools/perf by @dependabot in #2555
- Fix upsample converter not properly registered by @HolyWu in #2683
- feat: TS Add converter support for aten::grid_sampler by @mfeliz-cruise in #2717
- fix: Bump torchvision version by @gs-olive in #2770
- fix: convert_module_to_trt_engine by @zewenli98 in #2728
- chore: cherry pick of save API by @peri044 in #2719
- chore: Upgrade TensorRT version to TRT 10 EA (#2699) by @peri044 in #2774
- Fix minor grammatical corrections by @aakashapoorv in #2779
- feat: cherry-pick of Implement symbolic shape propagation, sym_size converter by @peri044 in #2751
- feat: cherry-pick of torch.compile dynamic shapes by @peri044 in #2750
- chore: bump deps for default workspace file by @narendasan in #2786
- fix: Point infra branch to main by @gs-olive in #2785
- "empty_like" decomposition test correction by @apbose in #2784
- chore: Bump versions by @narendasan in #2787
- fix: refactor layer norm converter with INormalization Layer by @zewenli98 in #2755
- TRT-10 GA Support for main branch by @zewenli98 in #2781
- chore(//tests): Update tests to use assertEqual by @narendasan in #2800
- feat: Add support for is_causal argument in attention by @gs-olive in #2780
- feat: Adding support for native int64 by @narendasan in #2789
- chore: small mypy issue by @narendasan in #2803
- Rand converter - evaluator by @apbose in #2580
- cherry-pick: Python Runtime Windows Builds on TRT 10 (#2764) by @gs-olive in #2776
- feat: support 1d ITensor offsets for embedding_bag converter by @zewenli98 in #2677
- chore(deps): bump transformers from 4.36.0 to 4.38.0 in /tools/perf by @dependabot in #2766
- fix: a bug in func run_test_compare_tensor_attributes_only by @zewenli98 in #2809
- Fix ModuleNotFoundError in ptq by @HolyWu in #2814
- docs: Example on how to use custom kernels in Torch-TensorRT by @narendasan in #2812
- typo fix in doc on saving models by @laikhtewari in #2818
- chore: Remove CUDNN dependencies by @zewenli98 in #2804
- fix: bug in elementwise base for static inputs by @zewenli98 in #2819
- Use environment for docgen by @atalman in #2826
- tool: Opset coverage notebook by @narendasan in #2831
- ci: Add release flag for nightly build tag by @gs-olive in #2821
- [doc] Update options documentation for torch.compile by @lanluo-nvidia in #2834
- feat(//py/torch_tensorrt/dynamo): Support for BF16 by @narendasan in #2833
- feat: data parallel inference examples by @bowang007 in #2805
- fix: bugs in TRT 10 upgrade by @zewenli98 in #2832
- feat: support aten._cdist_forward converter by @chohk88 in #2726
- chore: cherry pick of #2805 by @bowang007 in #2851
- feat: Add support for multi-device safe mode in C++ by @gs-olive in #2824
- feat: support aten.log1p converter by @chohk88 in #2823
- feat: support aten.as_strided converter by @chohk88 in #2735
- fix: Fix deconv kernel channel num_output_maps where wts are ITensor by @andi4191 in #2678
- Aten scatter converter by @apbose in #2664
- fix user_guide and tutorial docs by @yoosful in #2854
- chore: Make from and to methods use the same TRT API by @narendasan in #2858
- add aten.topk implementation by @lanluo-nvidia in #2841
- feat: support aten.atan2.out converter by @chohk88 in #2829
- chore: update docker, refactor CI TRT dep to main by @peri044 in #2793
- feat: Cherry pick of Add validators for dynamic shapes in converter registration by @peri044 in #2849
- feat: support aten.diagonal converter by @chohk88 in #2856
- Remove ops from decompositions where converters exist by @HolyWu in #2681
- slice_scatter decomposition by @apbose in #2519
- select_scatter decomp by @apbose in #2515
- manylinux wheel file build update for TensorRT-10.0.1 by @lanluo-nvidia in #2868
- replace itemset due to numpy version 2.0 removed itemset api by @lanluo-nvidia in #2879
- chore: cherry-pick of DS feature by @peri044 in #2857
- feat: TS Add converter supp...
Torch-TensorRT v2.3.0
Windows Support, Dynamic Shape and Quantization in Dynamo, PyTorch 2.3, CUDA 12.1, TensorRT 10.0
Torch-TensorRT 2.3.0 targets PyTorch 2.3, CUDA 12.1 (builds for CUDA 11.8 are available via the PyTorch package index - https://download.pytorch.org/whl/cu118) and TensorRT 10.0. 2.3.0 adds official support for Windows as a platform. Windows will only support using the Dynamo frontend and currently users are required to use the Python-only runtime (support for the C++ runtime will be added in a future version). This release also adds support for Dynamic shape without recompilation. Users can also now use quantized models with Torch-TensorRT using the Model Optimizer toolkit (https://github.com/NVIDIA/TensorRT-Model-Optimizer).
Note: Python 3.12 is not supported as the Dynamo stack in PyTorch 2.3.0 does not support Python 3.12
Windows
In this release we introduce Windows support for the Python runtime using the Dynamo paths. Users can now directly optimize PyTorch models with TensorRT on Windows with minimal code changes. This integration enables Python-only optimization in the Torch-TensorRT Dynamo compilation paths (ir="dynamo" and ir="torch_compile").
import torch
import torch_tensorrt
import torchvision.models as models
model = models.resnet18(pretrained=True).eval().to("cuda")
input = torch.randn((1, 3, 224, 224)).to("cuda")
trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=[input])
trt_mod(input)
Dynamic Shaped Model Compilation in Dynamo
Dynamic shape support has become more robust in v2.3.0. Torch-TensorRT now leverages symbolic information in the graph to calculate intermediate shape ranges, which allows more dynamic shape cases to be supported. For AOT workflows using torch.export, using these new features requires no changes. For JIT workflows, which previously relied on torch.compile guards to automatically recompile the engines when the input size changes, users can now mark dynamic dimensions using torch APIs (https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html). As long as inputs do not violate the specified constraints, engines will not recompile.
AOT workflow
import torch
import torch_tensorrt
compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(4, 3, 224, 224),
            max_shape=(8, 3, 224, 224),
            dtype=torch.float32,
        )
    ],
    "enabled_precisions": {torch.float},
}
trt_model = torch_tensorrt.compile(model, **compile_spec)
JIT workflow
import torch
import torch_tensorrt
compile_spec = {"enabled_precisions": {torch.float}}
inputs = torch.randn((4, 3, 224, 224)).to("cuda")
# This indicates the dimension 0 is dynamic and the range is [1, 8]
torch._dynamo.mark_dynamic(inputs, 0, min=1, max=8)
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
More information can be found here: https://pytorch.org/TensorRT/user_guide/dynamic_shapes.html
Explicit Dynamic Shape support in Converters
Converters now explicitly declare their support for dynamic shapes, and we are progressively adding and verifying coverage. Converter writers can specify support for dynamic shapes using the supports_dynamic_shapes argument of the dynamo_tensorrt_converter decorator.
@dynamo_tensorrt_converter(
    torch.ops.aten.convolution.default,
    capability_validator=lambda conv_node: conv_node.args[7] in ([0], [0, 0], [0, 0, 0]),
    supports_dynamic_shapes=True,
)  # type: ignore[misc]
def aten_ops_convolution(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
By default, if a converter has not been marked as supporting dynamic shape, its operator will be run in PyTorch when the user has specified dynamic inputs. This is done to ensure that compilation succeeds with some valid compiled module. However, many operators already support dynamic shape in an untested fashion. Therefore, users can choose to enable the full converter library for dynamic shape using the assume_dynamic_shape_support flag. This flag assumes all converters support dynamic shape, leading to more operations being run in TensorRT, with the potential drawback that some ops may cause compilation or runtime failures. Future releases will progressively add verified dynamic shape coverage for all Core ATen operators.
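For example, a minimal sketch (using a torchvision ResNet purely for illustration) of opting into this behavior:
import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")

trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[torch_tensorrt.Input(min_shape=(1, 3, 224, 224),
                                 opt_shape=(4, 3, 224, 224),
                                 max_shape=(8, 3, 224, 224),
                                 dtype=torch.float32)],
    # Treat all converters as dynamic-shape capable (may surface failures for some ops)
    assume_dynamic_shape_support=True,
)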
Quantization in Dynamo
We introduce support for model quantization in FP8. We support models quantized using NVIDIA TensorRT-Model-Optimizer toolkit. This toolkit introduces quantization nodes in the graph which are converted and used by TensorRT to quantize the model into lower precision. Although the toolkit supports quantization in various datatypes, we only support FP8 in this release.
Please refer to our end-end example Torch Compile VGG16 with FP8 and PTQ on how to use this.
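At a high level, the workflow looks roughly like the following sketch (assuming the modelopt.torch.quantization API and names as used in the linked example; exact calls may differ across modelopt versions, so treat this as an outline rather than the supported code):
import torch
import torch_tensorrt
import torchvision.models as models
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

model = models.vgg16(pretrained=True).eval().to("cuda")
example_input = torch.randn((1, 3, 224, 224)).to("cuda")

def calibrate_loop(model):
    # Run representative data through the model so activation ranges can be calibrated
    for _ in range(8):
        model(example_input)

# Insert FP8 quantize/dequantize nodes into the graph
mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate_loop)

# Export and compile; the Q/DQ nodes are lowered to TensorRT
with torch.no_grad(), export_torch_mode():
    exp_program = torch.export.export(model, (example_input,))
trt_model = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=[example_input],
    enabled_precisions={torch.float8_e4m3fn, torch.float16},
)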
Engine Version and Hardware Compatibility
We introduce new compilation arguments, hardware_compatible: bool and version_compatible: bool, which enable two key features in TensorRT.
hardware_compatible
Enabling hardware compatibility mode will generate TRT Engines which are compatible with Ampere and newer GPUs. As a result, engines built on one GPU can later be run on others, without requiring recompilation.
version_compatible
Enabling version compatibility mode will generate TRT Engines which are compatible with newer versions of TensorRT. As a result, engines built with one version of TensorRT will be forward compatible with other TRT versions, without needing recompilation.
...
trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=[input], hardware_compatible=True, version_compatible=True)
...
New Data Type Support
Torch-TensorRT includes a number of new data types that leverage dedicated hardware on Ampere, Hopper and future architectures.
bfloat16 has been added as a supported type alongside FP16 and FP32 and can be enabled for additional kernel tactic options. Models that contain BF16 weights can now be provided to Torch-TensorRT without modification. FP8 has been added, with support for Hopper and newer architectures, as a new quantization format (see above), similar to INT8. Finally, native support for INT64 inputs and computation has been added. In the past, the truncate_long_and_double feature flag had to be enabled in order to handle INT64 and FLOAT64 computation, inputs, and weights; this flag caused the compiler to truncate any INT64 or FLOAT64 objects to INT32 and FLOAT32 respectively. Now INT64 objects are no longer truncated and remain INT64. As such, the truncate_long_and_double flag has been renamed truncate_double, since FLOAT64 truncation is still required; truncate_long_and_double is now deprecated.
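For example, a minimal sketch (assuming enabled_precisions accepts torch.bfloat16 as described above, and using a torchvision ResNet purely for illustration):
import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
input = torch.randn((1, 3, 224, 224)).to("cuda")

# Enable BF16 kernels alongside FP32, and truncate FP64 weights/inputs to FP32
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[input],
    enabled_precisions={torch.float32, torch.bfloat16},
    truncate_double=True,
)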
What's Changed
- feat: support group_norm, batch_norm, and layer_norm by @zewenli98 in #2330
- support argmax converter by @bowang007 in #2291
- feat: Decomposition for _unsafe_index by @gs-olive in #2386
- docs: Add documentation of torch.compile backend usage by @gs-olive in #2363
- fix: Remove supported ops from decompositions by @gs-olive in #2390
- fix: Converter, inputs, and utils bugfixes for Transformer XL by @gs-olive in #2404
- feat: support embedding_bag converter (1D input) by @zewenli98 in #2395
- feat: support chunk dynamo converter by @zewenli98 in #2401
- chore: Add documentation for dynamo.compile backend by @peri044 in #2389
- Support new FX Legacy Registry in opset coverage tool by @laikhtewari in #2366
- fix: type error in embedding_bag by @zewenli98 in #2418
- feat: support cumsum dynamo converter by @zewenli98 in #2403
- 2.0 docs overhaul by @narendasan in #2420
- feat: support tile dynamo converter by @zewenli98 in #2402
- chore: update perf tooling to add dynamo options by @peri044 in #2423
- feat: Add aten.unbind decomposition for VIT by @gs-olive in #2430
- fix: Segfault fix for Benchmarks by @gs-olive in #2432
- examples: Stable Diffusion torch.compile sample with output image by @gs-olive in #2417
- minor fix: Parse out slashes in Docker container name by @gs-olive in #2437
- fix: Docs rendering on PyTorch site by @gs-olive in #2440
- Numpy changes for aten::index converter by @apbose in #2396
- feat: a lowering pass to re-compose ops into aten.linear by @zewenli98 in #2411
- chore: fix docs for export by @peri044 in #2447
- chore: add additional BN native converter by @peri044 in #2446
- minor fix: Update Benchmark values by @gs-olive in #2453
- Dele...
Torch-TensorRT v2.2.0
Dynamo Frontend for Torch-TensorRT, PyTorch 2.2, CUDA 12.1, TensorRT 8.6
Torch-TensorRT 2.2.0 targets PyTorch 2.2, CUDA 12.1 (builds for CUDA 11.8 are available via the PyTorch package index - https://download.pytorch.org/whl/cu118) and TensorRT 8.6. This release is the second major release of Torch-TensorRT, as the default frontend has changed from TorchScript to Dynamo, allowing users to more easily control and customize the compiler in Python.
The dynamo frontend can support both JIT workflows through torch.compile and AOT workflows through torch.export + torch_tensorrt.compile. It targets the Core ATen Opset (https://pytorch.org/docs/stable/torch.compiler_ir.html#core-aten-ir) and currently has 82% coverage. Just like with TorchScript, graphs will be partitioned based on the ability to map operators to TensorRT, in addition to any graph surgery done in Dynamo.
Output Format
Through the Dynamo frontend, different output formats can be selected for AOT workflows via the output_format kwarg. The choices are torchscript, where the resulting compiled module will be traced with torch.jit.trace and is suitable for Python-less deployments; exported_program, a new serializable format for PyTorch models; or, if you would like to run further graph transformations on the resultant model, graph_module, which returns a torch.fx.GraphModule.
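For example, a minimal sketch (using a torchvision ResNet purely for illustration) of selecting the TorchScript output format described above:
import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
input = torch.randn((1, 3, 224, 224)).to("cuda")

# Produce a TorchScript module suitable for Python-less deployment
ts_module = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[input],
    output_format="torchscript",
)
torch.jit.save(ts_module, "trt_model.ts")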
Multi-GPU Safety
To address a long-standing source of overhead, single-GPU systems will now operate without the typically required device checks. This check can be re-added when multiple GPUs are available to the host process using torch_tensorrt.runtime.set_multi_device_safe_mode.
# Enables Multi Device Safe Mode
torch_tensorrt.runtime.set_multi_device_safe_mode(True)
# Disables Multi Device Safe Mode [Default Behavior]
torch_tensorrt.runtime.set_multi_device_safe_mode(False)
# Enables Multi Device Safe Mode, then resets the safe mode to its prior setting
with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
...
More information can be found here: https://pytorch.org/TensorRT/user_guide/runtime.html
Capability Validators
In the Dynamo frontend, tests can be written and associated with converters to dynamically enable or disable them based on conditions in the target graph.
For example, the convolution converter in dynamo only supports 1D, 2D, and 3D convolution. We can therefore create a lambda which given a convolution FX node can determine if the convolution is supported:
@dynamo_tensorrt_converter(
    torch.ops.aten.convolution.default,
    capability_validator=lambda conv_node: conv_node.args[7] in ([0], [0, 0], [0, 0, 0]),
)  # type: ignore[misc]
def aten_ops_convolution(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
In such a case, where the Node is not supported, the node will be partitioned out and run in PyTorch.
All capability validators are run prior to partitioning, after the lowering phase.
More information on writing converters for the Dynamo frontend can be found here: https://pytorch.org/TensorRT/contributors/dynamo_converters.html
Breaking Changes
- Dynamo (torch.export) is now the default frontend for Torch-TensorRT. The TorchScript and FX frontends are now in maintenance mode. Therefore, any torch.nn.Modules or torch.fx.GraphModules provided to torch_tensorrt.compile will by default be exported using torch.export and then compiled. This default can be overridden by setting the ir=[torchscript|fx] kwarg. Any bugs reported will first be attempted to be resolved in the Dynamo stack before attempting other frontends; however, pull requests for additional functionality in the TorchScript and FX frontends from the community will still be accepted.
What's Changed
- chore: Update Torch and Torch-TRT versions and docs on main by @gs-olive in #1784
- fix: Repair invalid schema arising from lowering pass by @gs-olive in #1786
- fix: Allow full model compilation with collection inputs (input_signature) by @gs-olive in #1656
- feat(//core/conversion): Add support for aten::size with dynamic shaped models for Torchscript backend. by @peri044 in #1647
- feat: add support for aten::baddbmm by @mfeliz-cruise in #1806
- [feat] Add dynamic conversion path to aten::mul evaluator by @mfeliz-cruise in #1710
- [fix] aten::stack with dynamic inputs by @mfeliz-cruise in #1804
- fix undefined attr issue by @bowang007 in #1783
- fix: Out-Of-Bounds bug in Unsqueeze by @gs-olive in #1820
- feat: Upgrade Docker build to use custom TRT + CUDNN by @gs-olive in #1805
- fix: include str ivalue type conversion by @bowang007 in #1785
- fix: dependency order of inserted long input casts by @mfeliz-cruise in #1833
- feat: Add ts converter support for aten::all.dim by @mfeliz-cruise in #1840
- fix: Error caused by invalid binding name in TRTEngine.to_str() method by @gs-olive in #1846
- fix: Implement aten.mean.default and aten.mean.dim converters by @gs-olive in #1810
- feat: Add converter for aten::log2 by @mfeliz-cruise in #1866
- feat: Add support for aten::where with scalar other by @mfeliz-cruise in #1855
- feat: Add converter support for logical_and by @mfeliz-cruise in #1856
- feat: Refactor FX APIs under dynamo namespace for parity with TS APIs by @peri044 in #1807
- fix: Add version checking for torch._dynamo import in __init__ by @gs-olive in #1881
- fix: Improve Docker build robustness, add validation by @gs-olive in #1873
- fix: Improve input weight handling to acc_ops convolution layers in FX by @gs-olive in #1886
- fix: Upgrade main to TRT 8.6, CUDA 11.8, CuDNN 8.8, Torch Dev by @gs-olive in #1852
- feat: Wrap dynamic size handling in a compilation flag by @peri044 in #1851
- fix: Add torchvision legacy CI parameter by @gs-olive in #1918
- Sync fb internal change to OSS by @wushirong in #1892
- fix: Reorganize Dynamo directory + backends by @gs-olive in #1928
- fix: Improve partitioning + lowering systems in torch.compile path by @gs-olive in #1879
- fix: Upgrade TRT to 8.6.1, parallelize FX tests in CI by @gs-olive in #1930
- feat: Add issue template for Story by @gs-olive in #1936
- feat: support type promotion in aten::cat converter by @mfeliz-cruise in #1911
- Reorg for converters in (FX Converter Refactor [1/N]) by @narendasan in #1867
- fix: Add support for default dimension in aten.cat by @gs-olive in #1863
- Relaxing glob pattern for CUDA12 by @borisfom in #1950
- refactor: Centralizing sigmoid implementation (FX Converter Refactor [2/N]) <Target: converter_reorg_proto> by @narendasan in #1868
- fix: Address .numpy() issue on fake tensors by @gs-olive in #1949
- feat: Add support for passing through build issues in Dynamo compile by @gs-olive in #1952
- fix: int/int=float division by @mfeliz-cruise in #1957
- fix: Support dims < -1 in aten::stack converter by @mfeliz-cruise in #1947
- fix: Resolve issue in isInputDynamic with mixed static/dynamic shapes by @mfeliz-cruise in #1883
- DLFW changes by @apbose in #1878
- feat: Add converter for aten::isfinite by @mfeliz-cruise in #1841
- Reorg for converters in hardtanh(FX Converter Refactor [5/N]) <Target: converter_reorg_proto> by @apbose in #1901
- fix/feat: Add lowering pass to resolve most aten::Int.Tensor uses by @gs-olive in #1937
- fix: Add decomposition for aten.addmm by @gs-olive in #1953
- Reorg for converters tanh (FX Converter Refactor [4/N]) <Target: converter_reorg_proto> by @apbose in #1900
- Reorg for converters leaky_relu (FX Converter Refactor [6/N]) <Target: converter_reorg_proto> by @apbose in #1902
- Upstream 3 features to fx_ts_compat: MS, VC, Optimization Level by @wu6u3tw in #1935
- fix: Add lowering pass to remove output repacking in convert_method_to_trt_engine calls by @gs-olive in #1945
- Fixing aten::slice invalid schema and i...
Torch-TensorRT v1.4.0
PyTorch 2.0, CUDA 11.8, TensorRT 8.6, Support for the new torch.compile API, compatibility mode for FX frontend
Torch-TensorRT 1.4.0 targets PyTorch 2.0, CUDA 11.8, and TensorRT 8.6. This release introduces a number of beta features to set the stage for working with PyTorch and TensorRT in the 2.0 ecosystem. Primarily, this includes a new torch.compile backend targeting Torch-TensorRT. It also adds a compatibility layer that allows users of the TorchScript frontend for Torch-TensorRT to seamlessly try FX and Dynamo.
torch.compile Backend for Torch-TensorRT
One of the most prominent new features in PyTorch 2.0 is the torch.compile workflow, which enables users to accelerate code easily by specifying a backend of their choice. Torch-TensorRT 1.4.0 introduces a new backend for torch.compile as a beta feature, including a convenience frontend to perform accelerated inference. This frontend can be accessed in one of two ways:
import torch_tensorrt
torch_tensorrt.dynamo.compile(model, inputs, ...)
##### OR #####
torch_tensorrt.compile(model, ir="dynamo_compile", inputs=inputs, ...)
For more examples, see the provided sample scripts, which can be found here
This compilation method has a couple of key considerations:
- It can handle models with data-dependent control flow
- It automatically falls back to Torch if the TRT Engine Build fails for any reason
- It uses the Torch FX aten library of converters to accelerate models
- Recompilation can be caused by changing the batch size of the input, or providing an input which enters a new control flow branch
- Compiled models cannot be saved across Python sessions (yet)
The feature is currently in beta, and we expect updates, changes, and improvements to the above in the future.
fx_ts_compat Frontend
As the ecosystem transitions from TorchScript to Dynamo, users of Torch-TensorRT may want to start experimenting with this stack. As such, we have introduced a new frontend for Torch-TensorRT which exposes the same APIs as the TorchScript frontend but uses the FX/Dynamo compiler stack. You can try this frontend by using the ir="fx_ts_compat" setting:
torch_tensorrt.compile(..., ir="fx_ts_compat")
What's Changed
- Fix build by @yinghai in #1479
- add circle CI signal in README page by @yinghai in #1481
- fix eisum signature by @yinghai in #1480
- Fix link to CircleCI in README.md by @yinghai in #1483
- Minor changes by @yinghai in #1482
- [FX] Changes done internally at Facebook by @frank-wei in #1456
- chore: upload docs for 1.3.0 by @narendasan in #1504
- fix: Repair Citrinet-1024 compilation issues by @gs-olive in #1488
- refactor: Split elementwise tests by @peri044 in #1507
- [feat] Support 1D topk by @mfeliz-cruise in #1491
- Support aten::sum with bool tensor input by @mfeliz-cruise in #1512
- [fix]Disambiguate cast layer names by @mfeliz-cruise in #1513
- feat: Add functionality for easily benchmarking fx code on key models by @gs-olive in #1506
- [feat]Canonicalize aten::multiply to aten::mul by @mfeliz-cruise in #1517
- broadcast the two input shapes for transposed matmul by @nvpohanh in #1457
- make padding layer converter more efficient by @nvpohanh in #1470
- fix: Change equals-check from reference to value for BERT model not compiling in FX by @gs-olive in #1539
- Update README dependencies section for v1.3.0 by @take-cheeze in #1540
- fix: aten::where with differing-shape inputs bugfix by @gs-olive in #1533
- fix: Automatically send truncated long ints to cuda at shape analysis time by @gs-olive in #1541
- feat: Add functionality to FX benchmarking + Improve documentation by @gs-olive in #1529
- [fix] Fix crash when calling unbind on evaluated tensor by @mfeliz-cruise in #1554
- Update test_flatten_aten and test_reshape_aten due to PT2.0 changed tracer behavior for these ops by @frank-wei in #1559
- fix: Bugfix for align_corners=False - FX interpolate by @gs-olive in #1561
- fix: Properly cast intermediate Int8 tensors to TensorRT Engines in Fallback by @gs-olive in #1549
- Upgrade stack to Pytorch 2.0 + CUDA 11.7 + TRT 8.5 GA by @peri044 in #1477
- feat: Add option to specify int64 as an Input dtype by @gs-olive in #1551
- feat: Support int inputs to aten::max/min and aten::argmax/argmin by @mfeliz-cruise in #1574
- fix: Add aten::full_like evaluator by @gs-olive in #1584
- tools: assign 1 person to a bug instead of all by @narendasan in #1604
- feat: Add support for aten::meshgrid by @mfeliz-cruise in #1601
- [FX] Changes done internally at Facebook by @frank-wei in #1603
- chore: Add FX core test by @peri044 in #1593
- chore: Update dockerfile by @peri044 in #1581
- fix: Replace RemoveDropout lowering pass implementation with modified JIT pass by @gs-olive in #1589
- [FX] Changes done internally at Facebook by @frank-wei in #1625
- chore: Update Dockerfile to Ubuntu 20.04 + Crash Resolution by @gs-olive in #1639
- fix: Bugfix in Linear-to-AddMM Fusion Lowering Pass by @gs-olive in #1619
- fix: Resolve compilation bug for empty tensors in aten::select by @gs-olive in #1623
- Convolution cast by @apbose in #1609
- fix: Bugfix in TRT Engine deserialization indexing by @gs-olive in #1646
- fix: fix the inappropriate lowering pass of aten::to by @bowang007 in #1649
- Lowering aten::pad to aten::constant_pad_nd/aten::reflection_padXd/aten::replication_padXd by @ruoqianguo in #1588
- [fix] Disambiguate element-wise cast layer names by @mfeliz-cruise in #1630
- feat: Add optional tensor domain argument to Input class by @gs-olive in #1537
- Improve batch_norm fp16 accuracy by @mfeliz-cruise in #1450
- add an example of aten2trt, fix batch norm pass by @frank-wei in #1685
- fix: Issue in non-Tensor Input Resolution by @gs-olive in #1617
- Corrected a typo, which was raising an error by @zshn25 in #1694
- Cherry-pick manylinux compatible builds into main by @narendasan in #1677
- fix: Improve input handling for input_signature by @gs-olive in #1698
- Unsqueeze operator with dynamic inout by @apbose in #1624
- [feat] Add converter support for index_select by @mfeliz-cruise in #1692
- [feat] Add converter support for aten::logical_not by @mfeliz-cruise in #1705
- fix: Bugfix in convNd_to_convolution lowering pass by @gs-olive in #1693
- [feat] Add converter for aten::any.dim by @mfeliz-cruise in #1707
- [fix] resolve issue for single non-batch index tensor in aten::index by @mfeliz-cruise in #1700
- fix: Handle nonetype pad value for Constant pad by @peri044 in #1712
- infra: Add Torch 1.13.1 testing to nightly CI by @gs-olive in #1731
- fix: Allow full model compilation with collection outputs by @gs-olive in #1599
- fix: fix the prim::Loop fallback issue by @bowang007 in #1691
- feat: Add decorator utility to improve error messaging for legacy support by @gs-olive in #1738
- minor fix: Update default minimum torch version for aten tracer by @gs-olive in #1747
- Get windows build working by @bharrisau in #1711
- Update config.yml by @frank-wei in #1736
- fix: Bugfix in shape analysis for multi-GPU systems by @gs-olive in #1765
- fix: Add schemas to convolution lowering pass by @gs-olive in #1728...
Torch-TensorRT v1.3.0
PyTorch 1.13, CUDA 11.7, TensorRT 8.5, Support for Dynamic Batch for Partially Compiled Modules, Engine Profiling, Experimental Unified Runtime for FX and TorchScript Frontends
Torch-TensorRT 1.3.0 targets PyTorch 1.13, CUDA 11.7, cuDNN 8.5, and TensorRT 8.5. This release focuses on adding support for dynamic batch sizes for partially compiled modules using the TorchScript frontend (this is also supported with the FX frontend). It also introduces a new execution profiling utility to understand the execution of specific engine sub-blocks, which can be used in conjunction with PyTorch profiling tools to understand the performance of your model post-compilation. Finally, this release introduces a new experimental unified runtime shared by both the TorchScript and FX frontends. This allows you to start using the FX frontend to generate torch.jit.trace-able compiled modules.
Dynamic Batch Sizes for Partially Compiled Modules via the TorchScript Frontend
A long-standing limitation of the partitioning system in the TorchScript frontend is the lack of support for dynamic shapes. In this release we address a major subset of these use cases with support for dynamic batch sizes for modules that will be partially compiled. Usage is the same as in the fully compiled workflow: using the torch_tensorrt.Input class, you may define the range of shapes that an input may take during runtime. This is represented as a set of three shapes: min, opt, and max. min and max define the dynamic range of the input tensor, while opt informs TensorRT what size to optimize for, provided there are multiple valid kernels available. TensorRT will select kernels that are valid for the full range of input shapes but are most efficient at the opt size. In this release, partially compiled module inputs can vary in shape only in the highest-order dimension.
For example:
min_shape: (1, 3, 128, 128)
opt_shape: (8, 3, 128, 128)
max_shape: (32, 3, 128, 128)
is a valid shape range. However:
min_shape: (1, 3, 128, 128)
opt_shape: (1, 3, 256, 256)
max_shape: (1, 3, 512, 512)
is not supported, since only the highest-order dimension may vary in this release.
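For example, a minimal sketch (assuming the TorchScript frontend and a scripted module named scripted_model, used here purely for illustration) of specifying the valid range above:
import torch
import torch_tensorrt

trt_model = torch_tensorrt.compile(
    scripted_model,
    inputs=[
        torch_tensorrt.Input(
            min_shape=(1, 3, 128, 128),
            opt_shape=(8, 3, 128, 128),
            max_shape=(32, 3, 128, 128),
            dtype=torch.float32,
        )
    ],
)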
Engine Profiling [Experimental]
This release introduces a number of profiling tools to measure the performance of TensorRT sub-blocks in compiled modules. These can be used in conjunction with PyTorch profiling tools to get a full picture of the performance of your model. Profiling for any particular sub-block can be enabled via the enable_profiling() method of any __torch__.classes.tensorrt.Engine attribute, or of any torch_tensorrt.TRTModuleNext. The profiler will dump trace files in /tmp by default, though this path can be customized by either setting the profile_path_prefix of __torch__.classes.tensorrt.Engine or passing profiling_results_dir="" as an argument to torch_tensorrt.TRTModuleNext.enable_profiling(). Traces can be visualized using the Perfetto tool (https://perfetto.dev).
Engine layer information can also be accessed using get_layer_info, which returns a JSON string with the layers / fusions that the engine contains.
Unified Runtime for FX and TorchScript Frontends [Experimental]
In previous versions of Torch-TensorRT, the FX and TorchScript frontends were mostly separate and each had distinct benefits and limitations. Torch-TensorRT 1.3.0 introduces a new unified runtime to support both FX and TorchScript, meaning that you can choose the compilation workflow that makes the most sense for your particular use case, be it pure Python conversion via FX or C++ TorchScript compilation. Both frontends use the same primitives to construct their compiled graphs, whether fully or partially compiled.
Basic Usage
The TorchScript frontend uses the new runtime by default. No additional workflow changes are necessary.
Note: The runtime ABI version was increased to support this feature; as such, models compiled with previous versions of Torch-TensorRT will need to be recompiled.
For the FX frontend, the new runtime can be chosen by setting use_experimental_fx_rt=True as part of your compile settings, either with torch_tensorrt.compile(my_mod, ir="fx", use_experimental_fx_rt=True, explicit_batch_dimension=True) or torch_tensorrt.fx.compile(my_mod, use_experimental_fx_rt=True, explicit_batch_dimension=True).
Note: The new runtime only supports explicit batch dimension
TRTModuleNext
The FX frontend will return a torch.nn.Module containing torch_tensorrt.TRTModuleNext submodules instead of torch_tensorrt.fx.TRTModules. The features of these modules are nearly identical but with a few key improvements.
- TRTModuleNext profiling dumps a trace visualizable with Perfetto (see above for more details).
- TRTModuleNext modules are torch.jit.trace-able, meaning you can save FX compiled modules as TorchScript for Python-less / C++ deployment scenarios. Traced compiled modules have the same deployment instructions as compiled modules produced by the TorchScript frontend.
- TRTModuleNext maintains the same serialization workflows TRTModule supports as well (state_dict / extra_state, torch.save/torch.load).
Examples
model_fx = model_fx.cuda()
inputs_fx = [i.cuda() for i in inputs_fx]
trt_fx_module_f16 = torch_tensorrt.compile(
    model_fx,
    ir="fx",
    inputs=inputs_fx,
    enabled_precisions={torch.float16},
    use_experimental_fx_rt=True,
    explicit_batch_dimension=True,
)
# Save model using torch.save
torch.save(trt_fx_module_f16, "trt.pt")
reload_trt_mod = torch.load("trt.pt")
# Trace and save the FX module in TorchScript
scripted_fx_module = torch.jit.trace(trt_fx_module_f16, example_inputs=inputs_fx)
scripted_fx_module.save("/tmp/scripted_fx_module.ts")
scripted_fx_module = torch.jit.load("/tmp/scripted_fx_module.ts")
... #Get a handle for a TRTModuleNext submodule
# Extract state dictionary
st = trt_mod.state_dict()
# Load the state dict into a new module
new_trt_mod = TRTModuleNext()
new_trt_mod.load_state_dict(st)
Using TRTModuleNext as an arbitrary TensorRT engine holder
Using TorchScript, you have long been able to embed an arbitrary TensorRT engine from any source in a TorchScript module using torch_tensorrt.ts.embed_engine_in_new_module. Now you can do this at the torch.nn.Module level by directly using TRTModuleNext and access all the benefits enumerated above.
trt_mod = TRTModuleNext(
    serialized_engine,
    name="TestModule",
    input_binding_names=input_names,
    output_binding_names=output_names,
)
The intention is, in a future release, to have torch_tensorrt.TRTModuleNext replace torch_tensorrt.fx.TRTModule as the default TensorRT module implementation. Feedback on this class, how it is used, the runtime in general, or associated features (profiler, engine inspector) is welcome.
What's Changed
- chore: Bump version to 1.2.0a0 by @narendasan in #1044
- feat: Extending nox for cxx11 ABI version by @andi4191 in #1013
- docs: Update the documentation theme to PyTorch by @narendasan in #1063
- Adding Code of Conduct file by @facebook-github-bot in #1061
- Update CONTRIBUTING.md by @frank-wei in #1064
- feat: Optimize hub.py download by @andi4191 in #1022
- Adding an action to automatically assign reviewers and assignees by @narendasan in #1078
- Add PR assigner support by @narendasan in #1080
- (//core): Align with prim::Enter in module fallback by @andi4191 in #991
- (//core): Added a variant for aten::split by @andi4191 in #992
- feat(nox): Replacing session with environment variable by @andi4191 in #1057
- Refactor the internal codebase from fx2trt_oss to torch_tensorrt by @frank-wei in #1104
- format by buildifier by @frank-wei in #1106
- [fx2trt] Modify lower setting class by @frank-wei in #1107
- Modified the notebooks directory's README file by @svenchilton in #1102
- [FX] Sync to OSS by @frank-wei in #1118
- [fx_acc] Add acc_tracer support for torch.mm by @khabinov in #1120
- Added Triton deployment instructions to documentation by @tanayvarshney in #1116
- amending triton deployment docs by @tanayvarshney in #1126
- fix: Update broken repo hyperlink by @lamhoangtung in #1131
- fix: Fix keep_dims functionality for aten::max by @peri044 in #1099
- fix(tests/core/partitioning): Fix tests of refactoring segmentation in partitioning by @peri044 in #1140
- feat(//tests): Update rtol and atol based tolerance for test cases by @andi4191 in #1055
- doc: add the explanation for partition phases on docs by @bowang007 in #1090
- feat (//cpp): Using atol and rtol based tolerance threshold for torchtrtc by @andi4191 in #1052
- CI/CD setup by @frank-wei in #1137
...
Torch-TensorRT v1.2.0
PyTorch 1.12, Collections based I/O, FX Frontend, torchtrtc custom op support, CMake build system and Community Windows Support
Torch-TensorRT 1.2.0 targets PyTorch 1.12, CUDA 11.6, cuDNN 8.4 and TensorRT 8.4. This release focuses on a couple of key new APIs to handle function I/O that uses collection types, which should enable whole new model classes to be compiled by Torch-TensorRT without source code modification. It also introduces the "FX Frontend", a new frontend for Torch-TensorRT which leverages FX, a high-level IR built into PyTorch with extensive Python APIs. For use cases which do not need to run outside of Python, this may be a strong option to try, as it is easily extensible in a familiar development environment. In Torch-TensorRT 1.2.0, the FX frontend should be considered beta-level in stability. torchtrtc has received improvements targeting the ability to handle operators outside of the core PyTorch op set, including custom operators from libraries such as torchvision and torchtext. Similarly, users can provide custom converters to torchtrtc to extend the compiler's support from the command line instead of having to write an application to do so. Finally, Torch-TensorRT introduces community-supported Windows and CMake support.
New Dependencies
nvidia-tensorrt
For previous versions of Torch-TensorRT, users had to install TensorRT via a system package manager and modify their LD_LIBRARY_PATH in order to set up Torch-TensorRT. Now users should install the TensorRT Python API as part of the installation procedure. This can be done via the following steps:
pip install nvidia-pyindex
pip install nvidia-tensorrt==8.4.3.1
pip install torch-tensorrt==1.2.0 -f https://github.com/pytorch/tensorrt/releases
Installing the TensorRT pip package will allow Torch-TensorRT to automatically load the TensorRT libraries without any modification to environment variables. It is also a necessary dependency for the FX Frontend.
torchvision
Some FX frontend converters are designed to target operators from 3rd party libraries like torchvision. As such, you must have torchvision installed in order to use them. However, this dependency is optional for cases where you do not need this support.
Jetson
Starting from this release we will be distributing precompiled binaries of our NGC release branches for aarch64 (as well as x86_64), starting with ngc/22.11. These releases are designed to be paired with NVIDIA distributed builds of PyTorch including the NGC containers and Jetson builds and are equivalent to the prepackaged distribution of Torch-TensorRT that comes in the containers. They represent the state of the master branch at the time of branch cutting so may lag in features by a month or so. These releases will come separately to minor version releases like this one. Therefore going forward, these NGC releases should be the primary release channel used on Jetson (including for building from source).
NOTE: NGC PyTorch builds are not identical to builds you might install through normal channels like pytorch.org. In the past this has caused issues in portability between pytorch.org builds and NGC builds. Therefore we strongly recommend in workflows such as exporting a TorchScript module on an x86 machine and then compiling on Jetson to ensure you are using the NGC container release on x86 for your host machine operations. More information about Jetson support can be found along side the 22.07 release (https://github.com/pytorch/TensorRT/releases/tag/v1.2.0a0.nv22.07)
Collections based I/O [Experimental]
Torch-TensorRT has previously operated under the assumption that nn.Module forward functions can trivially be reduced to the form forward([Tensor]) -> [Tensor]. Typically this implies functions of the form forward(Tensor, Tensor, ... Tensor) -> (Tensor, Tensor, ..., Tensor). However, as model complexity increases, grouping inputs may make it easier to manage many inputs, so function signatures similar to forward([Tensor], (Tensor, Tensor)) -> [Tensor] or forward((Tensor, Tensor)) -> (Tensor, (Tensor, Tensor)) might be more common. In Torch-TensorRT 1.2.0, more of these kinds of use cases are supported using the new experimental input_signature compile spec API. This API allows users to group Input specs similarly to how they might group the input Tensors used to call the original module's forward function. This informs Torch-TensorRT how to map a Tensor input from its location in a group to the engine, and from the engine back into its grouping returned to the user.
To make this concrete consider the following standard case:
class StandardTensorInput(nn.Module):
    def __init__(self):
        super(StandardTensorInput, self).__init__()

    def forward(self, x, y):
        r = x + y
        return r
x = torch.Tensor([1,2,3]).to("cuda")
y = torch.Tensor([4,5,6]).to("cuda")
module = StandardTensorInput().eval().to("cuda")
trt_module = torch_tensorrt.compile(
module,
inputs=[
torch_tensorrt.Input(x.shape),
torch_tensorrt.Input(y.shape)
],
min_block_size=1
)
out = trt_module(x,y)
print(out)
Here a user has defined two explicit tensor inputs and used the existing list based API to define the input specs.
With Torch-TensorRT, the following use cases are now possible using the new input_signature API:
- Tuple based input collection
class TupleInput(nn.Module):
    def __init__(self):
        super(TupleInput, self).__init__()

    def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
        r = z[0] + z[1]
        return r
x = torch.Tensor([1,2,3]).to("cuda")
y = torch.Tensor([4,5,6]).to("cuda")
module = TupleInput().eval().to("cuda")
trt_module = torch_tensorrt.compile(
module,
input_signature=((x, y),), # Note how inputs are grouped with the new API
min_block_size=1
)
out = trt_module((x,y))
print(out)
- List based input collection
class ListInput(nn.Module):
    def __init__(self):
        super(ListInput, self).__init__()

    def forward(self, z: List[torch.Tensor]):
        r = z[0] + z[1]
        return r
x = torch.Tensor([1,2,3]).to("cuda")
y = torch.Tensor([4,5,6]).to("cuda")
module = ListInput().eval().to("cuda")
trt_module = torch_tensorrt.compile(
module,
input_signature=([x,y],), # Again, note how inputs are grouped with the new API
min_block_size=1
)
out = trt_module([x,y])
print(out)
Note how the input specs (in this case just example tensors) are provided to the compiler. The input_signature argument expects a Tuple[Union[torch.Tensor, torch_tensorrt.Input, List, Tuple]] grouped in a format representative of how the function would be called. In these cases it is just a list or tuple of specs.
More advanced cases are supported as well:
- Tuple I/O
class TupleInputOutput(nn.Module):
    def __init__(self):
        super(TupleInputOutput, self).__init__()

    def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
        r1 = z[0] + z[1]
        r2 = z[0] - z[1]
        r1 = r1 * 10
        r = (r1, r2)
        return r
x = torch.Tensor([1,2,3]).to("cuda")
y = torch.Tensor([4,5,6]).to("cuda")
module = TupleInputOutput().eval().to("cuda")
trt_module = torch_tensorrt.compile(
module,
input_signature=((x,y),), # Again, note how inputs are grouped with the new API
min_block_size=1
)
out = trt_module((x,y))
print(out)
- List I/O
class ListInputOutput(nn.Module):
    def __init__(self):
        super(ListInputOutput, self).__init__()

    def forward(self, z: List[torch.Tensor]):
        r1 = z[0] + z[1]
        r2 = z[0] - z[1]
        r = [r1, r2]
        return r
x = torch.Tensor([1,2,3]).to("cuda")
y = torch.Tensor([4,5,6]).to("cuda")
module = ListInputOutput().eval().to("cuda")
trt_module = torch_tensorrt.compile(
module,
input_signature=([x,y],), # Again, note how inputs are grouped with the new API
min_block_size=1
)
out = trt_module([x,y])
print(out)
- Multiple Groups of Mixed Types
class MultiGroupIO(nn.Module):
    def __init__(self):
        super(MultiGroupIO, self).__init__()

    def forward(self, z: List[torch.Tensor], a: Tuple[torch.Tensor, torch.Tensor]):
        r1 = z[0] + z[1]
        r2 = a[0] + a[1]
        r3 = r1 - r2
        r4 = [r1, r2]
        return (r3, r4)
x = torch.Tensor([1,2,3]).to("cuda")
y = torch.Tensor([4,5,6]).to("cuda")
module = MultiGroupIO().eval().to("cuda")
trt_module = torch_tensorrt.compile(
module,
input_signature=([x,y],(x,y)), # Again, note how inputs are grouped with the new API
min_block_size=1
)
out = trt_module([x,y],(x,y))
print(out)
These features are supported in C++ as well:
torch::jit::Module mod;
try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    mod = torch::jit::load(path);
} catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
}
mod.eval();
mod.to(torch::kCUDA);
std::vector<torch::jit::IValue> inputs_;
for (auto in : inputs) {
    inputs_.push_back(torch::jit::IValue(in.clone()));
}
std::vector<torch::jit::IValue> complex_inputs;
auto input_list = c10::impl::GenericList(c10::TensorType::get());
input_list.push_back(inputs_[0]);
input_list.push_back(inputs_[0]);
torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
complex_inputs.push_back(input_list_ivalue);
auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));
c10::TypePtr elementType = input_shape_ivalue.type();
auto ...
Torch-TensorRT v1.1.1
Adding support for Torch-TensorRT on Jetpack 5.0 Developer Preview
Torch-TensorRT 1.1.1 is a patch release for Torch-TensorRT 1.1 that targets PyTorch 1.11, CUDA 11.4/11.3, TensorRT 8.4 EA/8.2 and cuDNN 8.3/8.2, and is intended to add support for Torch-TensorRT on Jetson / Jetpack 5.0 DP. Because this release is primarily targeted at adding Jetpack 5.0 DP support for the 1.1 feature set, we will not be distributing pre-compiled binaries for this release, so as not to break compatibility with the current stack for existing users who install directly from GitHub. Please follow the instructions for installation on Jetson in the documentation to install this release: https://pytorch.org/TensorRT/tutorials/installation.html#compiling-from-source
Known Limitations
- In testing, we have observed higher than normal numerical instability on Jetpack 5.0 DP. These issues are not observed on x86_64 based platforms. This numerical instability has not been found to decrease model accuracy in our test suite.
What's Changed
Full Changelog: v1.1.0...v1.1.1
Operators Supported
Operators Currently Supported Through Converters
- aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor)
- aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor)
- aten::abs(Tensor self) -> (Tensor)
- aten::acos(Tensor self) -> (Tensor)
- aten::acosh(Tensor self) -> (Tensor)
- aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> (Tensor)
- aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)
- aten::adaptive_avg_pool3d(Tensor self, int[3] output_size) -> (Tensor)
- aten::adaptive_max_pool1d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
- aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
- aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)
- aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
- aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
- aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))
- aten::asin(Tensor self) -> (Tensor)
- aten::asinh(Tensor self) -> (Tensor)
- aten::atan(Tensor self) -> (Tensor)
- aten::atanh(Tensor self) -> (Tensor)
- aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[0], bool ceil_mode=False, bool count_include_pad=True) -> (Tensor)
- aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
- aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
- aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta, Tensor? mean, Tensor? var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor)
- aten::bmm(Tensor self, Tensor mat2) -> (Tensor)
- aten::cat(Tensor[] tensors, int dim=0) -> (Tensor)
- aten::ceil(Tensor self) -> (Tensor)
- aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)
- aten::clamp_max(Tensor self, Scalar max) -> (Tensor)
- aten::clamp_min(Tensor self, Scalar min) -> (Tensor)
- aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)
- aten::cos(Tensor self) -> (Tensor)
- aten::cosh(Tensor self) -> (Tensor)
- aten::cumsum(Tensor self, int dim, *, int? dtype=None) -> (Tensor)
- aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::div.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> (Tensor)
- aten::div_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))
- aten::div_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
- aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)
- aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)
- aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::erf(Tensor self) -> (Tensor)
- aten::exp(Tensor self) -> (Tensor)
- aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))
- aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))
- aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor)
- aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)
- aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)
- aten::floor(Tensor self) -> (Tensor)
- aten::floor_divide(Tensor self, Tensor other) -> (Tensor)
- aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor)
- aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)
- aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor(a!))
- aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)
- aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> (Tensor)
- aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta, float eps, bool cudnn_enabled) -> (Tensor)
- aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> (Tensor)
- aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> (Tensor(a!))
- aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> (Tensor)
- aten::log(Tensor self) -> (Tensor)
- aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)
- aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)
- aten::matmul(Tensor self, Tensor other) -> (Tensor)
- aten::max(Tensor self) -> (Tensor)
- aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
- aten::max.other(Tensor self, Tensor other) -> (Tensor)
- aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> (Tensor)
- aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], int[2] dilation=[1, 1], bool ceil_mode=False) -> (Tensor)
- aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], int[3] dilation=[], bool ceil_mode=False) -> (Tensor)
- aten::mean(Tensor self, *, int? dtype=None) -> (Tensor)
- aten::mean.dim(Tensor self, int[] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
- aten::min(Tensor self) -> (Tensor)
- aten::min.other(Tensor self, Tensor other) -> (Tensor)
- aten::mul.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::mul.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
- aten::narrow(Tensor(a) self, int dim, int start, int length) -> (Tensor(a))
- aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> (Tensor(a))
- aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::neg(Tensor self) -> (Tensor)
- aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)
- aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))
- aten::pixel_shuffle(Tensor self, int upscale_factor) -> (Tensor)
- aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)
- aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)
- aten::prelu(Tensor self, Tensor weight) -> (Tensor)
- aten::prod(Tensor self, *, int? dtype=None) -> (Tensor)
- aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
- aten::reciprocal(Tensor self) -> (Tensor)
- aten::reflection_pad1d(Tensor self, int[2] padding) -> (Tensor)
- aten::reflection_pad2d(Tensor self, int[4] padding) -> (Tensor)
- aten::relu(Tensor input) -> (Tensor)
- aten::relu_(Tensor(a!) self) -> (Tensor(a!))
- aten::repeat(Tensor self, int[] repeats) -> (Tensor)
- aten::replication_pad1d(Tensor self, int[2] padding) -> (Tensor)
- aten::replication_pad2d(Tensor self, int[4] padding) -> (Tensor)
- aten::replication_pad3d(Tensor self, int[6] padding) -> (Tensor)
- aten::reshape(Tensor self, int[] shape) -> (Tensor)
- aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> (Tensor)
- aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
- aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
- aten::select.int(Tensor(a) self, int dim, int i...
Torch-TensorRT v1.1.0
Support for PyTorch 1.11, Various Bug Fixes, Partial aten::Int
support, New Debugging Tools, Removing Max Batch Size
Torch-TensorRT 1.1.0 targets PyTorch 1.11, CUDA 11.3, cuDNN 8.2 and TensorRT 8.2. Due to recent JetPack upgrades, this release does not support Jetson (Jetpack 5.0DP or otherwise). Jetpack 5.0DP support will arrive in a mid-cycle release (Torch-TensorRT 1.1.x) along with support for TensorRT 8.4. 1.1.0 also drops support for Python 3.6 as it has reached end of life. Following 1.0.0, this release is focused on stabilizing and improving the core of Torch-TensorRT. Many improvements have been made to the partitioning system, addressing limitations many users hit while trying to partially compile PyTorch modules. Torch-TensorRT 1.1.0 also addresses a long-standing issue with aten::Int
operators (albeit partially). Now certain common patterns which use aten::Int
can be handled by the compiler without resorting to partial compilation. Most notably, this means that models like BERT can be run end to end with Torch-TensorRT, resulting in significant performance gains.
New Debugging Tools
With this release we are introducing new syntactic sugar that makes it easier to debug Torch-TensorRT compilation and execution through the use of context managers. For example, in Torch-TensorRT 1.0.0 a common pattern to turn debug info on and then off again looked like this:
import torch_tensorrt
...
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)
trt_module = torch_tensorrt.compile(my_module, ...)
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Warning)
results = trt_module(input_tensors)
With Torch-TensorRT 1.1.0, this now can be done with the following code:
import torch_tensorrt
...
with torch_tensorrt.logging.debug():
trt_module = torch_tensorrt.compile(my_module,...)
results = trt_module(input_tensors)
You can also use this API to debug the Torch-TensorRT runtime:
import torch_tensorrt
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Error)
...
trt_module = torch_tensorrt.compile(my_module,...)
with torch_tensorrt.logging.warnings():
results = trt_module(input_tensors)
The following levels are available:
# Only internal TensorRT failures will be logged
with torch_tensorrt.logging.internal_errors():
# Internal TensorRT failures + Torch-TensorRT errors will be logged
with torch_tensorrt.logging.errors():
# All Errors plus warnings will be logged
with torch_tensorrt.logging.warnings():
# First verbosity level, information about major steps occurring during compilation and execution
with torch_tensorrt.logging.info():
# Second verbosity level, each step is logged + information about compiler state will be output
with torch_tensorrt.logging.debug():
# Third verbosity level, all above information + intermediate transformations of the graph during lowering
with torch_tensorrt.logging.graphs():
Removing Max Batch Size, Strict Types
In this release we are removing the max_batch_size
and strict_types
settings. These settings directly corresponded to TensorRT settings; however, they were not always respected, which often led to confusion. Therefore we thought it best to remove these settings, as deterministic behavior could not be ensured.
Porting forward from max_batch_size and strict_types:
- max_batch_size: The first dim in shapes provided to Torch-TensorRT is considered the batch dimension, so instead of setting max_batch_size you can just use the Input objects directly (see the sketch below)
- strict_types: A replacement with more deterministic behavior will come with an upcoming TensorRT release.
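For illustration, here is a minimal sketch of expressing a dynamic batch dimension through the Input object (the model, shapes and precision are arbitrary choices, not taken from the original notes):
import torch
import torchvision.models as models
import torch_tensorrt

# Illustrative sketch: the batch size is just the first dimension of the Input spec
module = torch.jit.script(models.resnet18(pretrained=True).eval().to("cuda"))
trt_module = torch_tensorrt.compile(
    module,
    inputs=[
        torch_tensorrt.Input(
            min_shape=[1, 3, 224, 224],   # smallest batch the engine must handle
            opt_shape=[8, 3, 224, 224],   # batch size TensorRT optimizes for
            max_shape=[32, 3, 224, 224],  # upper bound, taking the place of max_batch_size
            dtype=torch.half,
        )
    ],
    enabled_precisions={torch.half},
)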
Dependencies
- Bazel 5.1.1
- LibTorch 1.11.0
- CUDA 11.3 (on x86_64, by default, newer CUDA 11 supported with compatible PyTorch Build)
- cuDNN 8.2.4.15
- TensorRT 8.2.4.2
1.1.0 (2022-05-10)
Bug Fixes
- add at::adaptive_avg_pool1d in interpolate plugin and fix #791 (deb9f74)
- Added ipywidget dependency to notebook (0b2040a)
- Added test case names (296e98a)
- Added truncate_long_and_double (417c096)
- Adding truncate_long_and_double to ptq tests (3a0640a)
- Avoid resolving non-tensor inputs to torch segment_blocks unneccessarily (3e090ee)
- Considering rtol and atol in threshold comparison for floating point numbers (0b0ba8d)
- Disabled mobilenet_v2 test for DLFW CI (40c611f)
- fix bug that python api doesn't pass truncate_long_and_double value to internal.partition_info (828336d)
- fix bugs in aten::to (2ecd187)
- Fix BUILD file for tests/accuracy (8b0170e)
- Fix existing uninstallation of Torch-TRT (9ddd7a8)
- Fix for torch scripted module faiure with DLFW (88c02d9)
- Fix fuse addmm pass (58e9ea0)
- Fix pre_built name change in bazelrc (3ecee21)
- fix the bug that introduces kLong Tensor in prim::NumToTensor (2c3e1d9)
- Fix when TRT prunes away an output (9465e1d)
- Fixed bugs and addressed review comments (588e1d1)
- Fixed failures for host deps sessions (ec2232f)
- Fixed typo in the path (43fab56)
- Getting unsupported ops will now bypass non-schema ops avoiding redundant failures (d7d1511)
- Guard test activation for CI testing (6d1a1fd)
- Implement a patch for gelu schema change in older NGC containers (9ee3a04)
- Missing log severity (6a4daef)
- Preempt torch package override via timm in nox session (8964d1b)
- refactor the resegmentation for TensorRT segments in ResolveNonTensorInput (3cc2dfb)
- remove outdated member variables (0268da2)
- Removed models directory dependencies (c4413e1)
- Resolve issues in exception elmination pass (99cea1b)
- Review comments incorporated (962660d)
- Review comments incorporated (e9865c2)
- support dict type for input in shape analysis (630f9c4)
- truncate_long_and_double incur torchscript inference issues (c83aa15)
- Typo fix for test case name (2a516b2)
- Update "reduceAxes" variable in GlobalPoolingConverter function and add corresponding uTests (f6f5e3e)
- //core/conversion/evaluators: Change how schemas are handled (20e5d41)
- Update base container for dockerfile (1b3245a)
- //core: Take user setting in the case we can't determine the (01c89d1), closes #814
- Update test for new Exception syntax (2357099)
- //core/conversion: Add special case for If and Loop (eacde8d)
- //core/runtime: Support more delimiter variants (819c911)
- //cpp/bin/torchtrtc: Fix mbs (aca175f)
- //docsrc: Fix dependencies for docgen (806e663)
- //notebooks: Render citrinet (12dbda1)
- //py: Constrain the CUDA version in container builds (a21a045)
- Use user provided dtype when we can't infer it from the graph (14650d1)
Code Refactoring
- removing the strict_types and max_batch_size apis (b30cbd9)
- Rename enabled precisions arugment to (10957eb)
- Removing the max-batch-size argument (03bafc5)
Features
- //core/conversion: Better tooling for debugging (c5c5c47)
- //core/conversion/evaluators: aten::pow support (c4fdfcb)
- //docker: New base container to let master build in container ([446bf18](https://github.com...
Torch-TensorRT v1.0.0
New Name!, Support for PyTorch 1.10, CUDA 11.3, New Packaging and Distribution Options, Stabilized APIs, Stabilized Partial Compilation, Adjusted Default Behavior, Usability Improvements, New Converters, Bug Fixes
This is the first stable release of Torch-TensorRT targeting PyTorch 1.10, CUDA 11.3 (on x86_64, CUDA 10.2 on aarch64), cuDNN 8.2 and TensorRT 8.0 with backwards compatible source for TensorRT 7.1. On aarch64 TRTorch targets Jetpack 4.6 primarily with backwards compatible source for Jetpack 4.5. This version also removes deprecated APIs such as InputRange
and op_precision
New Name
TRTorch is now Torch-TensorRT! TRTorch started out as a small experimental project compiling TorchScript to TensorRT almost two years ago and now, as we hit v1.0.0 with APIs and major features stabilizing, we felt that the name of the project should reflect the ecosystem of tools it is joining with this release, namely TF-TRT (https://blog.tensorflow.org/2021/01/leveraging-tensorflow-tensorrt-integration.html) and MXNet-TensorRT (https://mxnet.apache.org/versions/1.8.0/api/python/docs/tutorials/performance/backend/tensorrt/tensorrt). Since we were already significantly changing APIs with this release to reflect what we learned over the last two years of using TRTorch, we felt this was the right time to change the name as well.
The overall process to port forward from TRTorch is as follows:
- Python
  - The library has been renamed from trtorch to torch_tensorrt
  - Components that used to all live under the trtorch namespace have now been separated. IR agnostic components (torch_tensorrt.Input, torch_tensorrt.Device, torch_tensorrt.ptq, torch_tensorrt.logging) will continue to live under the top level namespace. IR specific components like torch_tensorrt.ts.compile, torch_tensorrt.ts.convert_method_to_trt_engine and torch_tensorrt.ts.TensorRTCompileSpec will live in a TorchScript specific namespace. This gives us space to explore the other IRs that might be relevant to the project in the future. In the place of the old top level compile and convert_method_to_engine are new ones which will call the IR specific versions based on what is provided to them. This also means that you can now provide a raw torch.nn.Module to torch_tensorrt.compile and Torch-TensorRT will handle the TorchScripting step for you. For the most part the sole change needed to change over namespaces is to exchange trtorch for torch_tensorrt (see the sketch after this list).
- C++
  - Similar to Python, the namespaces in C++ have changed from trtorch to torch_tensorrt, and components specific to the IR like compile, convert_method_to_trt_engine and CompileSpec are in a torchscript namespace, while agnostic components are at the top level. Namespace aliases for torch_tensorrt -> torchtrt and torchscript -> ts are included. Again, the port forward process for namespaces should be a find and replace. Finally, the libraries libtrtorch.so, libtrtorchrt.so and libtrtorch_plugins.so have been renamed to libtorchtrt.so, libtorchtrt_runtime.so and libtorchtrt_plugins.so respectively.
- CLI: trtorch has been renamed to torchtrtc
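For illustration, a minimal sketch of that find and replace in Python (the model choice and shapes are arbitrary, not part of the original notes):
import torch
import torchvision.models as models
import torch_tensorrt  # was: import trtorch

# Illustrative sketch: exchanging trtorch for torch_tensorrt is the main change
script_module = torch.jit.script(models.resnet18(pretrained=True).eval().to("cuda"))
trt_ts_module = torch_tensorrt.compile(  # was: trtorch.compile(...)
    script_module,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    enabled_precisions={torch.half},
)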
New Distribution Options and Packaging
Starting with nvcr.io/nvidia/pytorch:21.11
, Torch-TensorRT will be distributed as part of the container (https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch). The version of Torch-TensorRT in the container will be the state of master at the time of building. Torch-TensorRT will be validated to run correctly with the versions of PyTorch, CUDA, cuDNN and TensorRT in the container. This will serve as the easiest way to have a fully validated PyTorch end-to-end training-to-inference stack and serves as a great starting point for building DL applications.
Also, as part of the Torch-TensorRT Python packages, we are now starting to distribute the full C++ package within the wheel files. By installing the wheel you now get the Python API, the C++ libraries + headers and the CLI binary. This is going to be the easiest way to install Torch-TensorRT on your stack. After installing with pip:
pip3 install torch-tensorrt -f https://github.com/NVIDIA/Torch-TensorRT/releases
You can add the following to your PATH
to set up the CLI
PATH=$PATH:<PATH TO TORCHTRT PYTHON PACKAGE>/bin
Stabilized APIs
Python
Many of the APIs have changed slightly in this release to be more self-consistent and more usable. These changes begin with the Python API, for which compile
, convert_method_to_trt_engine
and TensorRTCompileSpec
now use kwargs instead of dictionaries. As many features came out of beta and experimental status, the necessity to have multiple levels of nesting in settings has decreased, so kwargs make much more sense. You can simply port forward to the new APIs by unwrapping your existing compile_spec
dict in the arguments to compile
or similar functions.
Example:
compile_settings = {
"inputs": [torch_tensorrt.Input(
min_shape=[1, 3, 224, 224],
opt_shape=[1, 3, 512, 512],
max_shape=[1, 3, 1024, 1024],
# For static size shape=[1, 3, 224, 224]
dtype=torch.half, # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
)],
"enabled_precisions": {torch.half}, # Run with FP16
}
trt_ts_module = torch_tensorrt.compile(torch_script_module, **compile_settings)
This release also introduces support for providing tensors as examples to Torch-TensorRT. In place of a torch_tensorrt.Input
in the list of inputs, you can pass a Tensor. This can only be used to set a static input size. There are also some things to be aware of, which will be discussed later in the release notes.
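For example, a minimal sketch (the model choice is illustrative, not from the original notes):
import torch
import torchvision.models as models
import torch_tensorrt

# Illustrative sketch: the example tensor acts as a static-shape input spec,
# roughly equivalent to torch_tensorrt.Input(shape=[1, 3, 224, 224])
module = torch.jit.script(models.resnet18(pretrained=True).eval().to("cuda"))
example_input = torch.randn(1, 3, 224, 224).to("cuda")
trt_module = torch_tensorrt.compile(module, inputs=[example_input])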
Now that Torch-TensorRT separates components specific to particular IRs into their own namespaces, there is a replacement for the old compile
and convert_method_to_trt_engine
functions at the top level. These functions take any PyTorch generated format, including torch.nn.Module
s, and decide the best way to compile it down to TensorRT. In v1.0.0 this means going through TorchScript and returning a torch.jit.ScriptModule
. You can specify the IR to try using the ir
arg for these functions.
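For example, a minimal sketch of the top level API (the model choice is illustrative; in v1.0.0 both calls go through TorchScript):
import torch
import torchvision.models as models
import torch_tensorrt

model = models.resnet18(pretrained=True).eval().to("cuda")  # a raw nn.Module, no scripting needed
inputs = [torch_tensorrt.Input((1, 3, 224, 224))]

# Let Torch-TensorRT decide how to lower the module (TorchScript in v1.0.0)
trt_model = torch_tensorrt.compile(model, inputs=inputs)

# Or request the TorchScript path explicitly via the ir argument
trt_model_ts = torch_tensorrt.compile(model, ir="ts", inputs=inputs)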
Due to partial compilation becoming stable in v1.0.0, there are now four new fields which replace the old torch_fallback
struct.
- old:
compile_spec = {
"torch_fallback": {
"enabled": True, # Turn on or turn off falling back to PyTorch if operations are not supported in TensorRT
"force_fallback_ops": [
"aten::max_pool2d" # List of specific ops to require running in PyTorch
],
"force_fallback_modules": [
"mypymod.mytorchmod" # List of specific torch modules to require running in PyTorch
],
"min_block_size": 3 # Minimum number of ops an engine must incapsulate to be run in TensorRT
}
}
- new:
torch_tensorrt.compile(...,
require_full_compilation=False,
min_block_size=3,
torch_executed_ops=[ "aten::max_pool2d" ],
torch_executed_modules=["mypymod.mytorchmod"])
C++
The changes for the C++ API, other than the reorganization and renaming of the namespaces, mostly serve to make Torch-TensorRT consistent between Python and C++, namely by renaming trtorch::CompileGraph
to torch_tensorrt::ts::compile
and trtorch::ConvertGraphToTRTEngine
to torch_tensorrt::ts::convert_method_to_trt_engine
. Beyond that similar to Python, the partial compilation struct TorchFallback
has been removed and replaced by four fields in torch_tensorrt::ts::CompileSpec
- old:
/**
* @brief A struct to hold fallback info
*/
struct TRTORCH_API TorchFallback {
/// enable the automatic fallback feature
bool enabled = false;
/// minimum consecutive operation number that needs to be satisfied to convert to TensorRT
uint64_t min_block_size = 1;
/// A list of names of operations that will explicitly run in PyTorch
std::vector<std::string> forced_fallback_ops;
/// A list of names of modules that will explicitly run in PyTorch
std::vector<std::string> forced_fallback_modules;
/**
* @brief Construct a default Torch Fallback object, fallback will be off
*/
TorchFallback() = default;
/**
* @brief Construct from a bool
*/
TorchFallback(bool enabled) : enabled(enabled) {}
/**
* @brief Constructor for setting min_block_size
*/
TorchFallback(bool enabled, uint64_t min_size) : enabled(enabled), min_block_size(min_size) {}
};
- new:
/**
* Require the full module be compiled to TensorRT instead of potentially running unsupported operations in PyTorch
*/
bool require_full_compilation = false;
/**
* Minimum number of contiguous supported operators to compile a subgraph to TensorRT
*/
uint64_t min_block_size = 3;
/**
* List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but
* ``require_full_compilation`` is True
*/
std::vector<std::string> torch_executed_ops;
/**
* List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but
* ``require_full_compilation`` is True
*/
std::vector<std::string> torch_executed_modules;
CLI
Similarly these partial compilation fields have been renamed in torchtrtc
:
--require-full-compilation Require that the model should be fully
compiled to TensorRT or throw an error
--teo=[torch-executed-ops...],
--torch-executed-ops=[torch-executed-ops...]
(Repeatable) Operator in the graph that
...