Replies: 3 comments
-
Couple notes:
-
There may be additional information, such as the weight map cache, that we should save alongside the TRT engine to further reduce the time spent in EngineCache. See here: https://github.com/pytorch/TensorRT/pull/2983/files
-
@narendasan Currently, the global timing cache (#2898) is used by default, and users seem unable to disable it. It is separate from engine caching and is used while building TRT engines. Do you want to add controls for it, or anything else? A simple use case would be perfect!
-
Engine Caching
Goal(s)
Boost performance when calling torch.compile() by reusing previously compiled TensorRT engines rather than recompiling the model every time, thereby avoiding recompilation time.
Proposed APIs
The API would be invoked via an argument to torch_tensorrt.compile, as follows:
If ignore_engine_cache=False (the default), the backend attempts to retrieve previously saved TensorRT engines from disk. On a hit, it reuses the saved engine rather than recompiling the model.
If ignore_engine_cache=True, the backend ignores any saved TensorRT engines. Instead, it recompiles the model and saves the new engine to disk.
This argument gives the user a layer of abstraction: engine caching is handled by Torch-TensorRT, and the acceleration benefits are immediate.
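The hit/miss behavior of the proposed flag can be modeled in a few lines. This is a minimal sketch that assumes the cache is a plain dict mapping a graph hash to serialized engine bytes. get_engine and build_engine are hypothetical names, not Torch-TensorRT APIs, and TensorRT itself is not involved.

```python
from typing import Callable, Dict


# Hypothetical model of the proposed ignore_engine_cache semantics.
# cache: graph hash -> serialized engine bytes (stand-in for on-disk storage)
# build_engine: stand-in for compiling a sub-module into a TRT engine
def get_engine(
    graph_hash: str,
    build_engine: Callable[[], bytes],
    cache: Dict[str, bytes],
    ignore_engine_cache: bool = False,
) -> bytes:
    if not ignore_engine_cache and graph_hash in cache:
        # Cache hit: reuse the previously saved engine and skip compilation.
        return cache[graph_hash]
    # Miss, or cache explicitly ignored: compile, then save the fresh engine.
    engine = build_engine()
    cache[graph_hash] = engine
    return engine
```

With ignore_engine_cache=True, the sketch still overwrites the stored entry. This matches the proposal, which says the new engine is saved to disk even when the old one is ignored.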
Design
There are four functions: get_hash, query, save, and load. Their functionality is described in the code as follows.
The pipeline is as follows:
We do the engine caching for sub-modules after the partition phase.
When querying, we check whether there is a hit in the engine cache.
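One possible shape for the four functions is sketched below, assuming one file per engine in a cache directory, named by its hash. DiskEngineCache and its method signatures are hypothetical. get_hash here hashes a JSON-serializable description of the graph with SHA-256 and stands in for hashing a real GraphModule.

```python
import hashlib
import json
from pathlib import Path
from typing import Any, Optional


class DiskEngineCache:
    """Hypothetical sketch of the get_hash / query / save / load interface."""

    def __init__(self, cache_dir: str) -> None:
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def get_hash(graph_description: Any) -> str:
        # Stand-in for hashing a GraphModule: hash a JSON description of it.
        blob = json.dumps(graph_description, sort_keys=True).encode("utf-8")
        return hashlib.sha256(blob).hexdigest()

    def _path(self, graph_hash: str) -> Path:
        # One file per cached engine, named by its hash.
        return self.cache_dir / f"{graph_hash}.engine"

    def query(self, graph_hash: str) -> bool:
        # True if a serialized engine with this hash exists on disk.
        return self._path(graph_hash).is_file()

    def save(self, graph_hash: str, serialized_engine: bytes) -> None:
        self._path(graph_hash).write_bytes(serialized_engine)

    def load(self, graph_hash: str) -> Optional[bytes]:
        # Return the serialized engine, or None on a miss.
        if not self.query(graph_hash):
            return None
        return self._path(graph_hash).read_bytes()
```

For each partitioned sub-module, the pipeline computes get_hash and calls query. On a hit it calls load; on a miss it builds the engine and calls save.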
Implementation
Isomorphic graph
If we want to reuse a compiled graph, the first question that comes to mind is how to determine whether two graphs are isomorphic, since we only reuse an old engine if it is the same as the new one.
Refit is used to reassign a new GraphModule's weights to an old TRT engine, so the engine cache can reuse refit. Hence, we only care about the architecture of the GraphModule and ignore its weights: whatever the weights are, two GraphModules with the same architecture are considered the same.
Hash graph
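The "same architecture, any weights" rule can be illustrated with a toy model. ToyGraph, same_architecture, and refit below are hypothetical and not Torch-TensorRT code. Architecture is taken to be the op sequence plus each weight's name and size. "Refit" is modeled as keeping the cached structure and taking weight values from the new graph, whereas real refit updates weights inside a TRT engine.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Tuple


@dataclass
class ToyGraph:
    """Hypothetical stand-in for a GraphModule: ops plus named weights."""
    ops: List[str]
    weights: Dict[str, List[float]] = field(default_factory=dict)

    def architecture(self) -> Tuple:
        # Op sequence plus each weight's name and size, but not its values.
        sizes = tuple(sorted((k, len(v)) for k, v in self.weights.items()))
        return (tuple(self.ops), sizes)


def same_architecture(a: ToyGraph, b: ToyGraph) -> bool:
    return a.architecture() == b.architecture()


def refit(cached: ToyGraph, new: ToyGraph) -> ToyGraph:
    """Reuse the cached structure, taking weight values from the new graph."""
    if not same_architecture(cached, new):
        raise ValueError("refit requires identical architectures")
    return ToyGraph(ops=list(cached.ops),
                    weights={k: list(v) for k, v in new.weights.items()})
```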
Since we only need to hash the architecture of the GraphModule, we strip the weights from it: in the implementation, all weights are replaced by 0. Then we reuse PyTorch Inductor's FxGraphCachePickler to hash the GraphModule.
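The zero-then-hash idea can be illustrated without PyTorch. In this sketch, json plus SHA-256 stands in for Inductor's FxGraphCachePickler, and a graph is reduced to an op list and a dict of flat weight lists. Both get_hash and zero_weights are hypothetical helpers.

```python
import hashlib
import json
from typing import Dict, List


def zero_weights(weights: Dict[str, List[float]]) -> Dict[str, List[float]]:
    # Keep each weight's name and length (its "shape"); drop its values.
    return {name: [0.0] * len(values) for name, values in weights.items()}


def get_hash(ops: List[str], weights: Dict[str, List[float]]) -> str:
    """Hash only the architecture: the op list plus zeroed-out weights."""
    payload = {"ops": ops, "weights": zero_weights(weights)}
    blob = json.dumps(payload, sort_keys=True).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()
```

Replacing weights with zeros, rather than deleting them, keeps their shapes in the hashed payload. Graphs that differ only in weight values therefore collide, as intended, while a change in a weight's shape still produces a different hash.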
Cache eviction
The Least Recently Used (LRU) algorithm will be used as the cache eviction strategy. We will preset a disk-space budget for storing TRT engines, which users can change.
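A byte-budgeted LRU policy can be sketched with collections.OrderedDict. LRUEngineCache is hypothetical and keeps engines in memory for brevity, while the proposal stores them on disk. The budget is measured in bytes of serialized engine rather than in entry count, because engines vary widely in size.

```python
from collections import OrderedDict
from typing import Optional


class LRUEngineCache:
    """Hypothetical LRU cache bounded by total bytes, not entry count."""

    def __init__(self, max_bytes: int) -> None:
        self.max_bytes = max_bytes  # user-adjustable size budget
        self.used_bytes = 0
        self._entries: OrderedDict = OrderedDict()  # oldest entry first

    def load(self, key: str) -> Optional[bytes]:
        if key not in self._entries:
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        return self._entries[key]

    def save(self, key: str, engine: bytes) -> None:
        if len(engine) > self.max_bytes:
            return  # can never fit within the budget; skip caching
        if key in self._entries:
            self.used_bytes -= len(self._entries.pop(key))
        # Evict least recently used engines until the new one fits.
        while self.used_bytes + len(engine) > self.max_bytes:
            _, evicted = self._entries.popitem(last=False)
            self.used_bytes -= len(evicted)
        self._entries[key] = engine
        self.used_bytes += len(engine)
```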
Cache structure