Skip to content

Commit

Permalink
moved backend and device selection out of main function
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpagnon committed Dec 9, 2024
1 parent 44301a1 commit c8d7971
Showing 1 changed file with 27 additions and 11 deletions.
38 changes: 27 additions & 11 deletions Sports2D/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,21 +221,18 @@ def setup_video(video_file_path, save_vid, vid_output_path):
return cap, out_vid, cam_width, cam_height, fps


def setup_pose_tracker(det_frequency, mode, tracking):
def setup_backend_device():
'''
Set up the RTMLib pose tracker with the appropriate model and backend.
If CUDA is available, use it with ONNXRuntime backend; else use CPU with openvino
INPUTS:
- det_frequency: int. The frequency of pose detection (every N frames)
- mode: str. The mode of the pose tracker ('lightweight', 'balanced', 'performance')
- tracking: bool. Whether to track persons across frames with RTMlib tracker
Set up the backend and device for the pose tracker based on the availability of hardware acceleration.
TensorRT is not supported by RTMLib yet: https://github.com/Tau-J/rtmlib/issues/12
OUTPUTS:
- pose_tracker: PoseTracker. The initialized pose tracker object
Selects the best option in the following order of priority:
1. GPU with CUDA and ONNXRuntime backend (if CUDAExecutionProvider is available)
2. GPU with ROCm and ONNXRuntime backend (if ROCMExecutionProvider is available, for AMD GPUs)
3. GPU with MPS or CoreML and ONNXRuntime backend (for macOS systems)
4. CPU with OpenVINO backend (default fallback)
'''

# If CUDA is available, use it with ONNXRuntime backend; else use CPU with openvino
try:
import torch
import onnxruntime as ort
Expand All @@ -262,6 +259,25 @@ def setup_pose_tracker(det_frequency, mode, tracking):
device = 'cpu'
backend = 'openvino'
logging.info(f"\nNo valid CUDA installation found: using OpenVINO backend with CPU.")

return backend, device


def setup_pose_tracker(det_frequency, mode, tracking):
'''
Set up the RTMLib pose tracker with the appropriate model and backend.
If CUDA is available, use it with ONNXRuntime backend; else use CPU with openvino
INPUTS:
- det_frequency: int. The frequency of pose detection (every N frames)
- mode: str. The mode of the pose tracker ('lightweight', 'balanced', 'performance')
- tracking: bool. Whether to track persons across frames with RTMlib tracker
OUTPUTS:
- pose_tracker: PoseTracker. The initialized pose tracker object
'''

backend, device = setup_backend_device()

# Initialize the pose tracker with Halpe26 model
pose_tracker = PoseTracker(
Expand Down

0 comments on commit c8d7971

Please sign in to comment.