Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add cpu tracing #773

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions torch2trt/tests/test_cpu_tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
import torch
from torch2trt import torch2trt


def test_cpu_tracing():

model = torch.nn.Conv2d(3, 3, kernel_size=1)

data = torch.randn(1, 3, 32, 32)

model_trt = torch2trt(model, [data])

assert(hasattr(model_trt, 'engine'))
assert(model_trt.engine is not None)

data = torch.randn(1, 3, 32, 32)
assert(torch.allclose(model(data), model_trt(data), atol=1e-3, rtol=1e-3))


if __name__ == '__main__':
test_cpu_tracing()
16 changes: 14 additions & 2 deletions torch2trt/torch2trt.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,8 @@ def add_inputs(self, torch_inputs, names=None, dynamic_axes=None):
shape=shape,
dtype=torch_dtype_to_trt(torch_input.dtype),
)
trt_tensor.location = torch_device_to_trt(torch_input.device)
# trt_tensor.location = torch_device_to_trt(torch_input.device)
trt_tensor.location = trt.TensorLocation.DEVICE
torch_input._trt = trt_tensor

def mark_outputs(self, torch_outputs, names=None):
Expand All @@ -520,7 +521,8 @@ def mark_outputs(self, torch_outputs, names=None):
for i, torch_output in enumerate(torch_outputs):
trt_tensor = torch_output._trt
trt_tensor.name = names[i]
trt_tensor.location = torch_device_to_trt(torch_output.device)
# trt_tensor.location = torch_device_to_trt(torch_output.device)
trt_tensor.location = trt.TensorLocation.DEVICE
trt_tensor.dtype = torch_dtype_to_trt(torch_output.dtype)
self.network.mark_output(trt_tensor)

Expand Down Expand Up @@ -576,9 +578,15 @@ def _load_from_state_dict(
def forward(self, *inputs):
bindings = [None] * (len(self.input_names) + len(self.output_names))

# flatten inputs
if self.input_flattener is not None:
inputs = self.input_flattener.flatten(inputs)

input_dtype = inputs[0].device

# place inputs on device
inputs = [t.cuda() for t in inputs]

for i, input_name in enumerate(self.input_names):
idx = self.engine.get_binding_index(input_name)
shape = tuple(inputs[i].shape)
Expand All @@ -600,6 +608,10 @@ def forward(self, *inputs):
bindings, torch.cuda.current_stream().cuda_stream
)

# map outputs to input dtype
outputs = [t.to(input_dtype) for t in outputs]

# unflatten outputs
if self.output_flattener is not None:
outputs = self.output_flattener.unflatten(outputs)
else:
Expand Down