Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,14 @@ TRTEngine::TRTEngine(
out_binding_names[pyt_idx] = binding_name;
}
num_io = std::make_pair(inputs_size, outputs);

this->current_device_id = at::cuda::current_device();
this->stream = c10::cuda::getCurrentCUDAStream(this->current_device_id);
this->io_size = this->cuda_engine->getNbIOTensors();
for (int64_t i = 0; i < this->in_binding_names.size(); i++) {
this->isShapeInferenceIO[this->in_binding_names[i]] =
this->cuda_engine->isShapeInferenceIO(this->in_binding_names[i].c_str());
}
}

#ifndef NDEBUG
Expand Down Expand Up @@ -281,6 +289,14 @@ void TRTEngine::enable_profiling() {
exec_ctx->setProfiler(trt_engine_profiler.get());
}

void TRTEngine::set_unowned_output_tensor(bool enable) {
this->unowned_output_tensor = enable;
}

bool TRTEngine::is_unowned_output_tensor() {
return this->unowned_output_tensor;
}

void TRTEngine::set_profile_format(std::string format) {
if (format == "trex") {
this->trt_engine_profiler->set_profile_format(TraceFormat::kTREX);
Expand Down
10 changes: 8 additions & 2 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ struct TRTEngine : torch::CustomClassHolder {
std::shared_ptr<nvinfer1::ICudaEngine> cuda_engine;
std::shared_ptr<nvinfer1::IExecutionContext> exec_ctx;
std::pair<uint64_t, uint64_t> num_io;
uint64_t io_size;
std::map<std::string, bool> isShapeInferenceIO;
bool unowned_output_tensor = false;
std::string name;
RTDevice device_info;

Expand Down Expand Up @@ -159,6 +162,8 @@ struct TRTEngine : torch::CustomClassHolder {
int64_t get_automatic_device_memory_budget();
std::vector<at::Tensor> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
void set_pre_allocated_outputs(bool enable);
void set_unowned_output_tensor(bool enable);
bool is_unowned_output_tensor();
TorchTRTRuntimeStates runtime_states;
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';
Expand All @@ -169,13 +174,14 @@ struct TRTEngine : torch::CustomClassHolder {

// CUDAGraph-Related Functionality
at::cuda::CUDAGraph cudagraph = {};
at::cuda::CUDAStream engine_stream = c10::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream caller_stream = c10::cuda::getDefaultCUDAStream();
at::cuda::CUDAStream stream = c10::cuda::getDefaultCUDAStream();
int64_t current_device_id = at::cuda::current_device();
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key = "None";
bool use_pre_allocated_outputs = false;
std::vector<at::Tensor> pre_allocated_outputs;
std::vector<at::Tensor> allocated_outputs;

// Output Allocator-Related Functionality
bool requires_output_allocator = false; // engine requires output allocator
Expand Down
113 changes: 44 additions & 69 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ void setup_input_tensors(
std::vector<at::Tensor> inputs,
c10::intrusive_ptr<TRTEngine> compiled_engine,
bool cudagraphs_enabled,
bool need_cudagraphs_record) {
bool need_cudagraphs_record,
bool shape_changed) {
// this is a buffer to store shape tensor input addresses throughout the runtime scope
std::list<std::vector<int64_t>> inputShapeTensorValues;
std::list<at::Tensor> formatted_inputs(compiled_engine->num_io.first);
Expand All @@ -117,7 +118,7 @@ void setup_input_tensors(
auto shape = core::util::toVec(dims);
LOG_DEBUG("Input Name: " << name << " Shape: " << dims);

if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
if (compiled_engine->isShapeInferenceIO[name]) {
// Shape tensor inputs are casted to int64 explicitly.
// Refer to
// https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
Expand Down Expand Up @@ -145,10 +146,10 @@ void setup_input_tensors(
// Create a new persistent input buffer
compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone());
}

TORCHTRT_CHECK(
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");

if (shape_changed) {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
}
if (cudagraphs_enabled) {
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
Expand Down Expand Up @@ -217,7 +218,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
compiled_engine->cudagraph.reset();
}

std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
std::vector<at::Tensor> outputs;

// Intialize inputs and outputs to be available throughout the succeeding scopes
{ // Input Setup
Expand All @@ -226,10 +227,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
input_profiler_guard =
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
}

setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, shape_changed);
// Check if input shapes can be inferred.
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
int32_t const io_size{compiled_engine->io_size};
std::vector<char const*> names(io_size);
int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data());
TORCHTRT_CHECK(
Expand All @@ -240,6 +240,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}

{ // Output Setup
bool new_outputs = false;
std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
if (compiled_engine->profile_execution) {
output_profiler_guard =
Expand All @@ -248,64 +249,59 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
if (can_use_pre_allocated_outputs) {
outputs = compiled_engine->pre_allocated_outputs;
} else {
outputs = create_output_tensors(compiled_engine);
if (compiled_engine->allocated_outputs.size() == 0 or compiled_engine->unowned_output_tensor or shape_changed) {
compiled_engine->allocated_outputs = create_output_tensors(compiled_engine);
new_outputs = true;
}
outputs = compiled_engine->allocated_outputs;
}

for (auto output_indices : compiled_engine->out_binding_map) {
auto pyt_idx = output_indices.second;
std::string name = compiled_engine->out_binding_names[pyt_idx];
if (need_cudagraphs_record) {
// If we are recording the cuda graph then we need to update the persistent output buffer
compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
}
if (new_outputs) {
for (auto output_indices : compiled_engine->out_binding_map) {
auto pyt_idx = output_indices.second;
std::string name = compiled_engine->out_binding_names[pyt_idx];
if (need_cudagraphs_record) {
// If we are recording the cuda graph then we need to update the persistent output buffer
compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
}

if (cudagraphs_enabled) {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(
name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
"Error while setting the output tensor address");
} else {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
"Error while setting the output tensor address");
if (cudagraphs_enabled) {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(
name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
"Error while setting the output tensor address");
} else {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
"Error while setting the output tensor address");
}
}
}
}

auto current_device_id = -1;
if (inputs.size() > 0) {
current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
} else if (outputs.size() > 0) {
current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart
}

compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
// Create a new stream if the engine stream is the default stream
compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
if (current_device_id != compiled_engine->current_device_id) {
compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
}
}

{ // Engine Execution (execute on engine stream)
c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream);

std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
if (compiled_engine->profile_execution) {
enqueue_profiler_guard =
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
}

// Block engine stream until results are available on caller stream
at::cuda::CUDAEvent caller_exec_complete;
caller_exec_complete.record(compiled_engine->caller_stream);
caller_exec_complete.block(compiled_engine->engine_stream);

if (!cudagraphs_enabled) {
// Direct execution uses the caller buffers directly
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream);
} else {
if (need_cudagraphs_record) {
// If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
c10::cuda::CUDAStream recording_stream = compiled_engine->stream;
compiled_engine->cudagraph.capture_begin();
compiled_engine->exec_ctx->enqueueV3(recording_stream);
compiled_engine->cudagraph.capture_end();
Expand All @@ -325,11 +321,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
}

// Block caller stream until engine execution is complete
at::cuda::CUDAEvent trt_exec_complete;
trt_exec_complete.record(compiled_engine->engine_stream);
trt_exec_complete.block(compiled_engine->caller_stream);

if (cudagraphs_enabled) {
// If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
Expand All @@ -354,7 +345,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
}

setup_input_tensors(inputs, compiled_engine, false, false);
setup_input_tensors(inputs, compiled_engine, false, false, true);
// Check if input shapes can be inferred.
int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
std::vector<char const*> names(io_size);
Expand All @@ -378,40 +369,24 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
auto current_device_id = -1;
if (inputs.size() > 0) {
current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
} else {
current_device_id = at::cuda::current_device();
}

compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
// Create a new stream if the engine stream is the default stream
compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
if (current_device_id != compiled_engine->current_device_id) {
compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
}
}

{ // Engine Execution (execute on engine stream)
c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream);

std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
if (compiled_engine->profile_execution) {
enqueue_profiler_guard =
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
}

// Block engine stream until results are available on caller stream
at::cuda::CUDAEvent caller_exec_complete;
caller_exec_complete.record(compiled_engine->caller_stream);
caller_exec_complete.block(compiled_engine->engine_stream);

// Direct execution uses the caller buffers directly
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream);

} // End engine exeuction (resets to caller stream)

// Block caller stream until engine execution is complete
at::cuda::CUDAEvent trt_exec_complete;
trt_exec_complete.record(compiled_engine->engine_stream);
trt_exec_complete.block(compiled_engine->caller_stream);

std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
if (compiled_engine->profile_execution) {
output_profiler_guard =
Expand Down
2 changes: 2 additions & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def("infer_outputs", &TRTEngine::infer_outputs)
.def("reset_captured_graph", &TRTEngine::reset_captured_graph)
.def("set_unowned_output_tensor", &TRTEngine::set_unowned_output_tensor)
.def("is_unowned_output_tensor", &TRTEngine::is_unowned_output_tensor)
.def_readwrite("use_pre_allocated_outputs", &TRTEngine::use_pre_allocated_outputs)
.def_readwrite("use_output_allocator_outputs", &TRTEngine::use_output_allocator_outputs)
.def_property(
Expand Down
6 changes: 5 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def preserve_module_specs(
trt_modules = {}
# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those

trt_module = None
for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)
# filter on the GraphModule
Expand Down Expand Up @@ -994,6 +994,10 @@ def preserve_module_specs(
) as f:
f.write(trt_module.get_layer_info())

# Only set the requires_unique_output flag for the last TRT Module when user has access to the output tensor
if trt_module:
trt_module.set_unowned_output_tensor(True)

# Parse the graph I/O and store it in dryrun tracker
parse_graph_io(gm, dryrun_tracker)

Expand Down
Loading