Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate TensorRT Version 10 into TensorRT exporter, importer, and forward pass #1799

Merged
merged 2 commits on Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions boxmot/appearance/backends/tensorrt_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

class TensorRTBackend(BaseModelBackend):
def __init__(self, weights, device, half):
self.is_trt10 = False
super().__init__(weights, device, half)
self.nhwc = False
self.half = half
Expand Down Expand Up @@ -40,22 +41,37 @@ def load_model(self, w):
self.context = self.model_.create_execution_context()
self.bindings = OrderedDict()

self.is_trt10 = not hasattr(self.model_, "num_bindings")
num = range(self.model_.num_io_tensors) if self.is_trt10 else range(self.model_.num_bindings)

# Parse bindings
for index in range(self.model_.num_bindings):
name = self.model_.get_binding_name(index)
dtype = trt.nptype(self.model_.get_binding_dtype(index))
is_input = self.model_.binding_is_input(index)
for index in num:
if self.is_trt10:
name = self.model_.get_tensor_name(index)
dtype = trt.nptype(self.model_.get_tensor_dtype(name))
is_input = self.model_.get_tensor_mode(name) == trt.TensorIOMode.INPUT
if is_input and -1 in tuple(self.model_.get_tensor_shape(name)):
self.context.set_input_shape(name, tuple(self.model_.get_tensor_profile_shape(name, 0)[1]))
if is_input and dtype == np.float16:
self.fp16 = True

shape = tuple(self.context.get_tensor_shape(name))

# Handle dynamic shapes
if is_input and -1 in self.model_.get_binding_shape(index):
profile_index = 0
min_shape, opt_shape, max_shape = self.model_.get_profile_shape(profile_index, index)
self.context.set_binding_shape(index, opt_shape)
else:
name = self.model_.get_binding_name(index)
dtype = trt.nptype(self.model_.get_binding_dtype(index))
is_input = self.model_.binding_is_input(index)

if is_input and dtype == np.float16:
self.fp16 = True
# Handle dynamic shapes
if is_input and -1 in self.model_.get_binding_shape(index):
profile_index = 0
min_shape, opt_shape, max_shape = self.model_.get_profile_shape(profile_index, index)
self.context.set_binding_shape(index, opt_shape)

shape = tuple(self.context.get_binding_shape(index))
if is_input and dtype == np.float16:
self.fp16 = True

shape = tuple(self.context.get_binding_shape(index))
data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(self.device)
self.bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))

Expand All @@ -64,12 +80,17 @@ def load_model(self, w):
def forward(self, im_batch):
# Adjust for dynamic shapes
if im_batch.shape != self.bindings["images"].shape:
i_in = self.model_.get_binding_index("images")
i_out = self.model_.get_binding_index("output")
self.context.set_binding_shape(i_in, im_batch.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im_batch.shape)
output_shape = tuple(self.context.get_binding_shape(i_out))
self.bindings["output"].data.resize_(output_shape)
if self.is_trt10:
self.context.set_input_shape("images", im_batch.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im_batch.shape)
self.bindings["output"].data.resize_(tuple(self.context.get_tensor_shape("output")))
else:
i_in = self.model_.get_binding_index("images")
i_out = self.model_.get_binding_index("output")
self.context.set_binding_shape(i_in, im_batch.shape)
self.bindings["images"] = self.bindings["images"]._replace(shape=im_batch.shape)
output_shape = tuple(self.context.get_binding_shape(i_out))
self.bindings["output"].data.resize_(output_shape)

s = self.bindings["images"].shape
assert im_batch.shape == s, f"Input size {im_batch.shape} does not match model size {s}"
Expand Down
14 changes: 11 additions & 3 deletions boxmot/appearance/exporters/tensorrt_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def export(self):

onnx_file = self.export_onnx()
LOGGER.info(f"\nStarting export with TensorRT {trt.__version__}...")
is_trt10 = int(trt.__version__.split(".")[0]) >= 10 # is TensorRT >= 10
assert onnx_file.exists(), f"Failed to export ONNX file: {onnx_file}"
f = self.file.with_suffix(".engine")
logger = trt.Logger(trt.Logger.INFO)
Expand All @@ -27,7 +28,11 @@ def export(self):

builder = trt.Builder(logger)
config = builder.create_builder_config()
config.max_workspace_size = self.workspace * 1 << 30
workspace = int(self.workspace * (1 << 30))
if is_trt10:
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
else: # TensorRT versions 7, 8
config.max_workspace_size = workspace

flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
network = builder.create_network(flag)
Expand Down Expand Up @@ -62,8 +67,11 @@ def export(self):
if builder.platform_has_fast_fp16 and self.half:
config.set_flag(trt.BuilderFlag.FP16)
config.default_device_type = trt.DeviceType.GPU
with builder.build_engine(network, config) as engine, open(f, "wb") as t:
t.write(engine.serialize())

build = builder.build_serialized_network if is_trt10 else builder.build_engine
with build(network, config) as engine, open(f, "wb") as t:
t.write(engine if is_trt10 else engine.serialize())

return f


Expand Down
Loading