diff --git a/boxmot/appearance/backends/tensorrt_backend.py b/boxmot/appearance/backends/tensorrt_backend.py
index 32c840dbdb..18273daf75 100644
--- a/boxmot/appearance/backends/tensorrt_backend.py
+++ b/boxmot/appearance/backends/tensorrt_backend.py
@@ -7,6 +7,7 @@ class TensorRTBackend(BaseModelBackend):
 
     def __init__(self, weights, device, half):
+        self.is_trt10 = False
         super().__init__(weights, device, half)
         self.nhwc = False
         self.half = half
 
@@ -40,22 +41,37 @@ def load_model(self, w):
         self.context = self.model_.create_execution_context()
         self.bindings = OrderedDict()
 
+        self.is_trt10 = not hasattr(self.model_, "num_bindings")
+        num = range(self.model_.num_io_tensors) if self.is_trt10 else range(self.model_.num_bindings)
+
         # Parse bindings
-        for index in range(self.model_.num_bindings):
-            name = self.model_.get_binding_name(index)
-            dtype = trt.nptype(self.model_.get_binding_dtype(index))
-            is_input = self.model_.binding_is_input(index)
+        for index in num:
+            if self.is_trt10:
+                name = self.model_.get_tensor_name(index)
+                dtype = trt.nptype(self.model_.get_tensor_dtype(name))
+                is_input = self.model_.get_tensor_mode(name) == trt.TensorIOMode.INPUT
+                if is_input and -1 in tuple(self.model_.get_tensor_shape(name)):
+                    self.context.set_input_shape(name, tuple(self.model_.get_tensor_profile_shape(name, 0)[1]))
+                if is_input and dtype == np.float16:
+                    self.fp16 = True
+
+                shape = tuple(self.context.get_tensor_shape(name))
 
-            # Handle dynamic shapes
-            if is_input and -1 in self.model_.get_binding_shape(index):
-                profile_index = 0
-                min_shape, opt_shape, max_shape = self.model_.get_profile_shape(profile_index, index)
-                self.context.set_binding_shape(index, opt_shape)
+            else:
+                name = self.model_.get_binding_name(index)
+                dtype = trt.nptype(self.model_.get_binding_dtype(index))
+                is_input = self.model_.binding_is_input(index)
 
-            if is_input and dtype == np.float16:
-                self.fp16 = True
+                # Handle dynamic shapes
+                if is_input and -1 in self.model_.get_binding_shape(index):
+                    profile_index = 0
+                    min_shape, opt_shape, max_shape = self.model_.get_profile_shape(profile_index, index)
+                    self.context.set_binding_shape(index, opt_shape)
 
-            shape = tuple(self.context.get_binding_shape(index))
+                if is_input and dtype == np.float16:
+                    self.fp16 = True
+
+                shape = tuple(self.context.get_binding_shape(index))
 
             data = torch.from_numpy(np.empty(shape, dtype=dtype)).to(self.device)
             self.bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
 
@@ -64,12 +80,17 @@ def load_model(self, w):
     def forward(self, im_batch):
         # Adjust for dynamic shapes
         if im_batch.shape != self.bindings["images"].shape:
-            i_in = self.model_.get_binding_index("images")
-            i_out = self.model_.get_binding_index("output")
-            self.context.set_binding_shape(i_in, im_batch.shape)
-            self.bindings["images"] = self.bindings["images"]._replace(shape=im_batch.shape)
-            output_shape = tuple(self.context.get_binding_shape(i_out))
-            self.bindings["output"].data.resize_(output_shape)
+            if self.is_trt10:
+                self.context.set_input_shape("images", im_batch.shape)
+                self.bindings["images"] = self.bindings["images"]._replace(shape=im_batch.shape)
+                self.bindings["output"].data.resize_(tuple(self.context.get_tensor_shape("output")))
+            else:
+                i_in = self.model_.get_binding_index("images")
+                i_out = self.model_.get_binding_index("output")
+                self.context.set_binding_shape(i_in, im_batch.shape)
+                self.bindings["images"] = self.bindings["images"]._replace(shape=im_batch.shape)
+                output_shape = tuple(self.context.get_binding_shape(i_out))
+                self.bindings["output"].data.resize_(output_shape)
 
         s = self.bindings["images"].shape
         assert im_batch.shape == s, f"Input size {im_batch.shape} does not match model size {s}"
diff --git a/boxmot/appearance/exporters/tensorrt_exporter.py b/boxmot/appearance/exporters/tensorrt_exporter.py
index ba8e341597..7f138da79b 100644
--- a/boxmot/appearance/exporters/tensorrt_exporter.py
+++ b/boxmot/appearance/exporters/tensorrt_exporter.py
@@ -19,6 +19,7 @@ def export(self):
         onnx_file = self.export_onnx()
         LOGGER.info(f"\nStarting export with TensorRT {trt.__version__}...")
+        is_trt10 = int(trt.__version__.split(".")[0]) >= 10  # is TensorRT >= 10
         assert onnx_file.exists(), f"Failed to export ONNX file: {onnx_file}"
         f = self.file.with_suffix(".engine")
         logger = trt.Logger(trt.Logger.INFO)
 
@@ -27,7 +28,11 @@ def export(self):
 
         builder = trt.Builder(logger)
         config = builder.create_builder_config()
-        config.max_workspace_size = self.workspace * 1 << 30
+        workspace = int(self.workspace * (1 << 30))
+        if is_trt10:
+            config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace)
+        else:  # TensorRT versions 7, 8
+            config.max_workspace_size = workspace
 
         flag = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
         network = builder.create_network(flag)
@@ -62,8 +67,11 @@ def export(self):
 
         if builder.platform_has_fast_fp16 and self.half:
             config.set_flag(trt.BuilderFlag.FP16)
         config.default_device_type = trt.DeviceType.GPU
-        with builder.build_engine(network, config) as engine, open(f, "wb") as t:
-            t.write(engine.serialize())
+
+        build = builder.build_serialized_network if is_trt10 else builder.build_engine
+        with build(network, config) as engine, open(f, "wb") as t:
+            t.write(engine if is_trt10 else engine.serialize())
+
         return f