Graph break overhead #3946
base: main
Changes from all commits
@@ -96,7 +96,8 @@ void setup_input_tensors(
     std::vector<at::Tensor> inputs,
     c10::intrusive_ptr<TRTEngine> compiled_engine,
     bool cudagraphs_enabled,
-    bool need_cudagraphs_record) {
+    bool need_cudagraphs_record,
+    bool shape_changed) {
   // this is a buffer to store shape tensor input addresses throughout the runtime scope
   std::list<std::vector<int64_t>> inputShapeTensorValues;
   std::list<at::Tensor> formatted_inputs(compiled_engine->num_io.first);
@@ -117,7 +118,7 @@ void setup_input_tensors(
     auto shape = core::util::toVec(dims);
     LOG_DEBUG("Input Name: " << name << " Shape: " << dims);

-    if (compiled_engine->cuda_engine->isShapeInferenceIO(name.c_str())) {
+    if (compiled_engine->isShapeInferenceIO[name]) {
       // Shape tensor inputs are casted to int64 explicitly.
       // Refer to
       // https://github.com/NVIDIA/TensorRT/blob/d2f4ef789a9a6ffdf37b55c3f81b486225f6b380/samples/common/sampleInference.cpp#L435
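The lookup above replaces a per-call query on the TensorRT engine with a cached flag stored on the TRTEngine wrapper. A minimal sketch of that caching pattern follows, assuming the map is filled once when the engine is set up; the population site is not shown in this excerpt, and the helper names here are illustrative, not from the PR.

```cpp
#include <string>
#include <unordered_map>
#include <vector>

// Illustrative stand-in for the per-engine state; only the caching pattern is the point.
struct ShapeIOFlagCache {
  std::unordered_map<std::string, bool> isShapeInferenceIO;

  // Populate once, e.g. right after engine deserialization. `query(name)` would
  // wrap cuda_engine->isShapeInferenceIO(name.c_str()) in the real runtime.
  template <typename Query>
  void populate(const std::vector<std::string>& input_names, Query&& query) {
    for (const auto& name : input_names) {
      isShapeInferenceIO[name] = query(name);
    }
  }

  // Hot path: a hash lookup instead of a TensorRT API call per input per execution.
  bool is_shape_io(const std::string& name) const {
    auto it = isShapeInferenceIO.find(name);
    return it != isShapeInferenceIO.end() && it->second;
  }
};
```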
@@ -145,10 +146,10 @@ void setup_input_tensors(
         // Create a new persistent input buffer
         compiled_engine->input_buffers[i] = std::move(formatted_inputs.back().clone());
       }

-      TORCHTRT_CHECK(
-          compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
+      if (shape_changed) {
Collaborator: Can't we use the same shape keys we used for cudagraphs? Why are we implementing another system?
+        TORCHTRT_CHECK(
+            compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");
+      }
       if (cudagraphs_enabled) {
         // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
         compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
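With the guard above, `setInputShape` is only issued when the input shapes actually differ from the previous call, so static-shape workloads save one TensorRT call per input per execution. The excerpt does not show how `shape_changed` is computed, so the sketch below is an assumption: compare the current shapes against the last-seen ones, conceptually the same kind of shape key the reviewer mentions for CUDA Graphs.

```cpp
#include <cstdint>
#include <vector>

// Illustrative shape key: one dimension vector per input tensor.
using ShapeKey = std::vector<std::vector<int64_t>>;

// Returns true (and updates the cached key) only when the shapes differ from the
// previous call, so callers can skip setInputShape and output reallocation.
bool update_shape_key(ShapeKey& cached, const ShapeKey& current) {
  if (cached == current) {
    return false;
  }
  cached = current;
  return true;
}
```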
@@ -217,7 +218,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     compiled_engine->cudagraph.reset();
   }

-  std::vector<at::Tensor> outputs(compiled_engine->num_io.second);
+  std::vector<at::Tensor> outputs;

   // Intialize inputs and outputs to be available throughout the succeeding scopes
   { // Input Setup
@@ -226,10 +227,9 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
       input_profiler_guard =
           std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
     }

-    setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record);
+    setup_input_tensors(inputs, compiled_engine, cudagraphs_enabled, need_cudagraphs_record, shape_changed);
     // Check if input shapes can be inferred.
-    int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
+    int32_t const io_size{compiled_engine->io_size};
     std::vector<char const*> names(io_size);
     int32_t const nbNames = compiled_engine->exec_ctx->inferShapes(names.size(), names.data());
     TORCHTRT_CHECK(
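Same idea as the shape-inference flag: the number of IO tensors is fixed once the engine is built, so it appears to be read once and stored on the TRTEngine as `io_size` instead of calling `getNbIOTensors()` on every execution.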
@@ -240,6 +240,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
   }

   { // Output Setup
+    bool new_outputs = false;
     std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
     if (compiled_engine->profile_execution) {
       output_profiler_guard =
@@ -248,64 +249,58 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     if (can_use_pre_allocated_outputs) {
       outputs = compiled_engine->pre_allocated_outputs;
     } else {
-      outputs = create_output_tensors(compiled_engine);
+      if (compiled_engine->allocated_outputs.size() == 0 or compiled_engine->output_tensors_are_unowned or
+          shape_changed) {
+        compiled_engine->allocated_outputs = create_output_tensors(compiled_engine);
+        new_outputs = true;
+      }
+      outputs = compiled_engine->allocated_outputs;
     }

-    for (auto output_indices : compiled_engine->out_binding_map) {
-      auto pyt_idx = output_indices.second;
-      std::string name = compiled_engine->out_binding_names[pyt_idx];
-      if (need_cudagraphs_record) {
-        // If we are recording the cuda graph then we need to update the persistent output buffer
-        compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
-      }
+    if (new_outputs) {
+      for (auto output_indices : compiled_engine->out_binding_map) {
+        auto pyt_idx = output_indices.second;
+        std::string name = compiled_engine->out_binding_names[pyt_idx];
+        if (need_cudagraphs_record) {
+          // If we are recording the cuda graph then we need to update the persistent output buffer
+          compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
+        }

-      if (cudagraphs_enabled) {
-        TORCHTRT_CHECK(
-            compiled_engine->exec_ctx->setTensorAddress(
-                name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
-            "Error while setting the output tensor address");
-      } else {
-        TORCHTRT_CHECK(
-            compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
-            "Error while setting the output tensor address");
+        if (cudagraphs_enabled) {
+          TORCHTRT_CHECK(
+              compiled_engine->exec_ctx->setTensorAddress(
+                  name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
+              "Error while setting the output tensor address");
+        } else {
+          TORCHTRT_CHECK(
+              compiled_engine->exec_ctx->setTensorAddress(name.c_str(), outputs[pyt_idx].data_ptr()),
+              "Error while setting the output tensor address");
+        }
+      }
       }
     }
   }

  auto current_device_id = -1;
  if (inputs.size() > 0) {
    current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
  } else if (outputs.size() > 0) {
    current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart
  }

  compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
  if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
    // Create a new stream if the engine stream is the default stream
-    compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
+    compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
  }

  { // Engine Execution (execute on engine stream)
    c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream);

    std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
    if (compiled_engine->profile_execution) {
      enqueue_profiler_guard =
          std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
    }

-    // Block engine stream until results are available on caller stream
-    at::cuda::CUDAEvent caller_exec_complete;
-    caller_exec_complete.record(compiled_engine->caller_stream);
-    caller_exec_complete.block(compiled_engine->engine_stream);

    if (!cudagraphs_enabled) {
      // Direct execution uses the caller buffers directly
-      compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
+      compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream);
    } else {
      if (need_cudagraphs_record) {
        // If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
-        c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
+        c10::cuda::CUDAStream recording_stream = compiled_engine->stream;
        compiled_engine->cudagraph.capture_begin();
        compiled_engine->exec_ctx->enqueueV3(recording_stream);
        compiled_engine->cudagraph.capture_end();
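The output-setup change caches the allocated output tensors on the engine and only reallocates them (and re-binds the TensorRT tensor addresses) when there were no cached outputs, when the previous outputs were handed to the caller as-is, or when the input shapes changed. A minimal sketch of that reuse policy, with illustrative names (`allocate_fn` stands in for `create_output_tensors`):

```cpp
#include <functional>
#include <vector>

// Illustrative cache of output buffers keyed by the conditions in the diff.
template <typename Tensor>
struct OutputCache {
  std::vector<Tensor> allocated_outputs;
  bool output_tensors_are_unowned = false;  // true when outputs were returned to the caller

  // Returns the outputs to bind for this run; `new_outputs` reports whether the
  // TensorRT tensor addresses must be set again.
  std::vector<Tensor>& get(bool shape_changed,
                           const std::function<std::vector<Tensor>()>& allocate_fn,
                           bool& new_outputs) {
    new_outputs = allocated_outputs.empty() || output_tensors_are_unowned || shape_changed;
    if (new_outputs) {
      allocated_outputs = allocate_fn();
    }
    return allocated_outputs;
  }
};
```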
@@ -325,11 +320,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
      compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
    }

-    // Block caller stream until engine execution is complete
-    at::cuda::CUDAEvent trt_exec_complete;
-    trt_exec_complete.record(compiled_engine->engine_stream);
-    trt_exec_complete.block(compiled_engine->caller_stream);
-
    if (cudagraphs_enabled) {
      // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
      for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
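Together with the `compiled_engine->stream` change above, dropping these CUDA events removes the per-invocation handshake between a caller stream and a separate engine stream; with many small TRT segments at graph breaks, that synchronization was pure overhead. A sketch of the simpler model, assuming the engine is enqueued directly on the stream PyTorch is already using for the device (naming follows the diff; surrounding plumbing is omitted):

```cpp
#include <NvInfer.h>
#include <c10/cuda/CUDAStream.h>

// Enqueue TensorRT work on the caller's current stream: ordering with the
// surrounding PyTorch ops comes from stream order, so no events are needed.
void enqueue_on_current_stream(nvinfer1::IExecutionContext& exec_ctx, c10::DeviceIndex device_id) {
  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(device_id);
  exec_ctx.enqueueV3(stream.stream());  // cudaStream_t of the current stream
}
```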
@@ -354,7 +344,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
          std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->input_profile_path);
    }

-    setup_input_tensors(inputs, compiled_engine, false, false);
+    setup_input_tensors(inputs, compiled_engine, false, false, true);
     // Check if input shapes can be inferred.
     int32_t const io_size{compiled_engine->cuda_engine->getNbIOTensors()};
     std::vector<char const*> names(io_size);
@@ -378,40 +368,22 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
  auto current_device_id = -1;
  if (inputs.size() > 0) {
    current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
  } else {
    current_device_id = at::cuda::current_device();
  }

  compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
  if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
    // Create a new stream if the engine stream is the default stream
-    compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
+    compiled_engine->stream = c10::cuda::getCurrentCUDAStream(current_device_id);
  }

  { // Engine Execution (execute on engine stream)
    c10::cuda::CUDAStreamGuard stream_guard(compiled_engine->engine_stream);

    std::unique_ptr<torch::autograd::profiler::RecordProfile> enqueue_profiler_guard;
    if (compiled_engine->profile_execution) {
      enqueue_profiler_guard =
          std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
    }

-    // Block engine stream until results are available on caller stream
-    at::cuda::CUDAEvent caller_exec_complete;
-    caller_exec_complete.record(compiled_engine->caller_stream);
-    caller_exec_complete.block(compiled_engine->engine_stream);

    // Direct execution uses the caller buffers directly
-    compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
+    compiled_engine->exec_ctx->enqueueV3(compiled_engine->stream);

  } // End engine exeuction (resets to caller stream)

-  // Block caller stream until engine execution is complete
-  at::cuda::CUDAEvent trt_exec_complete;
-  trt_exec_complete.record(compiled_engine->engine_stream);
-  trt_exec_complete.block(compiled_engine->caller_stream);

  std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
  if (compiled_engine->profile_execution) {
    output_profiler_guard =
@@ -43,6 +43,7 @@
 from torch_tensorrt.dynamo.partitioning._resource_partitioner import (
     resource_partition,
 )
+from torch_tensorrt.dynamo.runtime._stream_handler import handle_cuda_stream
 from torch_tensorrt.dynamo.utils import (
     deallocate_module,
     get_cpu_memory_usage,
@@ -557,7 +558,7 @@ def compile(
             stacklevel=2,
         )

-    if kwargs.get("use_explicit_typing", False) == False:
+    if not kwargs.get("use_explicit_typing", False):
         warnings.warn(
             "`use_explicit_typing` is deprecated. This setting will be removed and you should enable autocast instead.",
             DeprecationWarning,
@@ -949,6 +950,7 @@ def preserve_module_specs(
     for attr in dir(gm):
         if attr.startswith("_frozen_param"):
             delattr(gm, attr)
+    trt_module = None
Collaborator: What is this? Why does this variable need to be outside the scope of iterating through the module?
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
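From the diff, `trt_module = None` appears to exist so the name is defined even when no TRT modules are produced and so that, after the later `for name, trt_module in trt_modules.items()` loop, it still refers to the last module installed, which is then flagged via `set_output_tensors_as_unowned(True)`. The reviewer's concern is that this relies on iteration order rather than on the graph's actual outputs.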
@@ -1070,20 +1072,28 @@ def preserve_module_specs(
             ) as f:
                 f.write(trt_module.get_layer_info())

+    # Only set the requires_unique_output flag for the last TRT Module when user has access to the output tensor

     # Parse the graph I/O and store it in dryrun tracker
     parse_graph_io(gm, dryrun_tracker)

     # Replace all FX Modules with TRT Modules
     for name, trt_module in trt_modules.items():
         setattr(partitioned_module, name, trt_module)
         if settings.lazy_engine_init and not settings.enable_cross_compile_for_windows:
-            getattr(partitioned_module, name).setup_engine()
+            trt_module = getattr(partitioned_module, name)
+            trt_module.setup_engine()

+    if trt_module:
Collaborator: I feel like this is a pretty roundabout way to get the last engine.

Collaborator: Also the order from items is not deterministic, so I don't think this is guaranteed to give you the last module anyway. You might need to work back from outputs.
+        trt_module.set_output_tensors_as_unowned(True)
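The apparent intent: intermediate TRT submodules feed their outputs straight into the next submodule, so the runtime can keep reusing its cached output tensors for them; only the final module's outputs escape to user code and therefore must not be recycled, which is what marking them as unowned signals.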
     # Reset settings object to user specification after fallback to global partitioning mode
     if fast_partitioner_failed:
         settings.use_fast_partitioner = True

     dryrun_stats_display(dryrun_tracker, settings.dryrun)
+    if not settings.dryrun:
Collaborator: Why is this conditioned on dry run? Dry run should reflect the deployment graph.
+        handle_cuda_stream(partitioned_module)

     return partitioned_module
Comment: Is this PR supposed to implement the stream operator?