|
17 | 17 | ## INIT
|
18 | 18 | import re
|
19 | 19 | import sys
|
| 20 | +import cv2 |
20 | 21 | import subprocess
|
| 22 | +import logging |
21 | 23 | from pathlib import Path
|
22 | 24 |
|
23 | 25 | import numpy as np
|
24 | 26 | from scipy import interpolate
|
25 | 27 | import imageio_ffmpeg as ffmpeg
|
| 28 | +from rtmlib import PoseTracker, BodyWithFeet, Wholebody, Body |
26 | 29 |
|
27 | 30 | from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QTabWidget, QVBoxLayout
|
28 | 31 | from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
@@ -257,3 +260,172 @@ def euclidean_distance(q1, q2):
|
257 | 260 | euc_dist = np.sqrt(np.sum( [d**2 for d in dist]))
|
258 | 261 |
|
259 | 262 | return euc_dist
|
| 263 | + |
def _detect_inference_backend():
    '''
    Choose the device and backend used to run pose estimation.

    Preference order: CUDA, then ROCm, then MPS/CoreML (all three with the
    ONNXRuntime backend), and finally CPU with the OpenVINO backend.
    Any failure to import or query onnxruntime or torch is treated as
    "accelerator not available" instead of being raised.

    OUTPUTS:
    - device: str. 'cuda', 'rocm', 'mps', or 'cpu'
    - backend: str. 'onnxruntime' or 'openvino'
    - message: str. The message to log about the selected hardware
    '''

    try:
        import onnxruntime as ort
        providers = ort.get_available_providers()
    except Exception:
        providers = []

    has_cuda_provider = 'CUDAExecutionProvider' in providers
    has_rocm_provider = 'ROCMExecutionProvider' in providers

    if has_cuda_provider or has_rocm_provider:
        # torch is only needed to confirm that a GPU is actually usable,
        # so it is not imported when no GPU provider is installed
        try:
            import torch
            gpu_usable = bool(torch.cuda.is_available())
        except Exception:
            gpu_usable = False

        if gpu_usable and has_cuda_provider:
            return 'cuda', 'onnxruntime', "\nValid CUDA installation found: using ONNXRuntime backend with GPU."
        if gpu_usable and has_rocm_provider:
            return 'rocm', 'onnxruntime', "\nValid ROCM installation found: using ONNXRuntime backend with GPU."

    if 'MPSExecutionProvider' in providers or 'CoreMLExecutionProvider' in providers:
        return 'mps', 'onnxruntime', "\nValid MPS installation found: using ONNXRuntime backend with GPU."

    return 'cpu', 'openvino', "\nNo valid CUDA installation found: using OpenVINO backend with CPU."


def setup_pose_tracker(det_frequency, mode, pose_model = "HALPE_26"):
    '''
    Set up the RTMLib pose tracker with the appropriate model and backend.
    If CUDA (or ROCm, or MPS/CoreML) is available, use it with the ONNXRuntime
    backend; else use CPU with OpenVINO.
    The tracker is created with tracking=False and to_openpose=False.

    INPUTS:
    - det_frequency: int. The frequency of pose detection (every N frames). Must be >= 1
    - mode: str. The mode of the pose tracker ('lightweight', 'balanced', 'performance')
    - pose_model: str. The pose model, case-insensitive: 'HALPE_26' (body and feet),
                  'COCO_133' (body, feet, hands, and face), or 'COCO_17' (body)

    OUTPUTS:
    - pose_tracker: PoseTracker. The initialized pose tracker object

    RAISES:
    - ValueError if det_frequency is lower than 1 or if pose_model is not supported
    '''

    # Validate the arguments first, so that invalid input fails before the
    # (potentially slow) hardware detection
    if det_frequency > 1:
        frequency_msg = f'Inference run only every {det_frequency} frames. Inbetween, pose estimation tracks previously detected points.'
    elif det_frequency == 1:
        frequency_msg = 'Inference run on every single frame.'
    else:
        raise ValueError(f"Invalid det_frequency: {det_frequency}. Must be an integer greater or equal to 1.")

    # Select the appropriate model class based on pose_model
    model_key = pose_model.upper()
    if model_key == 'HALPE_26':
        ModelClass = BodyWithFeet
        model_msg = "Using HALPE_26 model (body and feet) for pose estimation."
    elif model_key == 'COCO_133':
        ModelClass = Wholebody
        model_msg = "Using COCO_133 model (body, feet, hands, and face) for pose estimation."
    elif model_key == 'COCO_17':
        ModelClass = Body
        model_msg = "Using COCO_17 model (body) for pose estimation."
    else:
        raise ValueError(f"Invalid model_type: {pose_model}. Must be 'HALPE_26', 'COCO_133', or 'COCO_17'. Use another network (MMPose, DeepLabCut, OpenPose, AlphaPose, BlazePose...) and convert the output files if you need another model. See documentation.")

    device, backend, hardware_msg = _detect_inference_backend()

    logging.info(hardware_msg)
    logging.info(frequency_msg)
    logging.info(model_msg)
    logging.info(f'Mode: {mode}.\n')

    # Initialize the pose tracker with the selected model, backend, and device
    pose_tracker = PoseTracker(
        ModelClass,
        det_frequency=det_frequency,
        mode=mode,
        backend=backend,
        device=device,
        tracking=False,
        to_openpose=False)

    return pose_tracker
| 338 | + |
def setup_webcam(webcam_id, save_video, vid_output_path, input_size):
    '''
    Set up webcam capture with OpenCV.

    INPUTS:
    - webcam_id: int. The ID of the webcam to capture from
    - save_video: bool. Whether to save the video output
    - vid_output_path: Path or str. The path to save the video output
    - input_size: tuple. The requested size of the input frame (width, height)

    OUTPUTS:
    - cap: cv2.VideoCapture. The webcam capture object
    - out_vid: cv2.VideoWriter or None. The video writer object (None if save_video is False)
    - cam_width: int. The actual width of the webcam frame
    - cam_height: int. The actual height of the webcam frame
    - fps: float. The frame rate reported by the webcam (30 if it reports 0)

    RAISES:
    - ValueError if the webcam cannot be opened
    '''

    # cv2.CAP_DSHOW launches faster but only works for windows and esc key does not work
    cap = cv2.VideoCapture(webcam_id)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Error: Could not open webcam #{webcam_id}. Make sure that your webcam is available and has the right 'webcam_id' (check in your Config.toml file).")

    # Request the desired size, then read back what the webcam actually applied
    requested_width, requested_height = input_size[0], input_size[1]
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, requested_width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, requested_height)
    cam_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    cam_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps == 0:
        fps = 30

    if cam_width != requested_width or cam_height != requested_height:
        logging.warning(f"Warning: Your webcam does not support {input_size[0]}x{input_size[1]} resolution. Resolution set to the closest supported one: {cam_width}x{cam_height}.")

    out_vid = None
    if save_video:
        # 'mp4v' is used because 'avc1' (h264) may fail on some systems.
        # fourcc MJPG produces very large files but is faster: if writing is
        # too slow, consider using it and then converting the video to h264.
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # str(): a Path object is not accepted as a filename by every OpenCV version
        out_vid = cv2.VideoWriter(str(vid_output_path), fourcc, fps, (cam_width, cam_height))
        if not out_vid.isOpened():
            logging.warning(f"Could not open the video writer for {vid_output_path}.")

    return cap, out_vid, cam_width, cam_height, fps
| 385 | + |
| 386 | + |
def setup_video(video_file_path, save_video, vid_output_path):
    '''
    Set up video capture with OpenCV.

    INPUTS:
    - video_file_path: Path or str. The path to the video file (with extension)
    - save_video: bool. Whether to save the video output
    - vid_output_path: Path or str. The path to save the video output

    OUTPUTS:
    - cap: cv2.VideoCapture. The video capture object
    - out_vid: cv2.VideoWriter or None. The video writer object (None if save_video is False)
    - cam_width: int. The width of the video
    - cam_height: int. The height of the video
    - fps: float. The frame rate reported for the video (30 if it reports 0)

    RAISES:
    - ValueError if video_file_path has no extension
    - NameError if the file cannot be opened as a video
    '''

    video_file_path = Path(video_file_path)

    # A name without extension cannot be a video file
    if video_file_path.name == video_file_path.stem:
        raise ValueError("Please set video_input to 'webcam' or to a video file (with extension) in Config.toml")

    not_a_video_msg = f"{video_file_path} is not a video. Check video_dir and video_input in your Config.toml file."
    try:
        # str(): a Path object is not accepted as a filename by every OpenCV version
        cap = cv2.VideoCapture(str(video_file_path))
    except Exception as err:
        raise NameError(not_a_video_msg) from err
    if not cap.isOpened():
        cap.release()
        raise NameError(not_a_video_msg)

    cam_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    cam_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps == 0:
        fps = 30

    out_vid = None
    if save_video:
        # 'mp4v' is used because 'avc1' (h264) may fail on some systems
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(str(vid_output_path), fourcc, fps, (cam_width, cam_height))
        if not out_vid.isOpened():
            logging.warning(f"Could not open the video writer for {vid_output_path}.")

    return cap, out_vid, cam_width, cam_height, fps
0 commit comments