Skip to content

Commit f620d03

Browse files
Improving functions interoperability
1 parent ebefa82 commit f620d03

File tree

3 files changed

+238
-214
lines changed

3 files changed

+238
-214
lines changed

Sports2D/Sports2D.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ def base_params(config_dict):
254254
# videod_dir and result_dir
255255
video_dir = Path(config_dict.get('project').get('video_dir')).resolve()
256256
if video_dir == '': video_dir = Path.cwd()
257-
result_dir = Path(config_dict.get('process').get('result_dir')).resolve()
258-
if result_dir == '': result_dir = Path.cwd()
257+
output_dir = Path(config_dict.get('process').get('result_dir')).resolve()
258+
if output_dir == '': output_dir = Path.cwd()
259259

260260
# video_files, frame_rates, time_ranges, frame_ranges
261261
video_input = config_dict.get('project').get('video_input')
@@ -306,7 +306,7 @@ def base_params(config_dict):
306306
else:
307307
raise ValueError('Frame range must be [] for analysing all frames of all videos, or [start_frame, end_frame] for analysing all videos from start_frame to end_frame, or [[start_frame1, end_frame1], [start_frame2, end_frame2], ...] for analysing each video for a different frame_range.')
308308

309-
return video_dir, video_files, time_ranges, frame_ranges, frame_rates, result_dir
309+
return video_dir, video_files, time_ranges, frame_ranges, frame_rates, output_dir
310310

311311

312312
def get_leaf_keys(config, prefix=''):
@@ -372,17 +372,22 @@ def process(config='Config_demo.toml'):
372372
'''
373373

374374
from Sports2D.process import process_fun
375+
from Sports2D.Utilities.common import setup_pose_tracker
375376

376377
if type(config) == dict:
377378
config_dict = config
378379
else:
379380
config_dict = read_config_file(config)
380-
video_dir, video_files, time_ranges, frame_ranges, frame_rates, result_dir = base_params(config_dict)
381+
382+
mode = config_dict.get('pose').get('mode')
383+
det_frequency = config_dict.get('pose').get('det_frequency')
384+
385+
video_dir, video_files, time_ranges, frame_ranges, frame_rates, output_dir = base_params(config_dict)
381386

382-
result_dir.mkdir(parents=True, exist_ok=True)
383-
with open(result_dir / 'logs.txt', 'a+'): pass
387+
output_dir.mkdir(parents=True, exist_ok=True)
388+
with open(output_dir / 'logs.txt', 'a+'): pass
384389
logging.basicConfig(format='%(message)s', level=logging.INFO, force=True,
385-
handlers = [logging.handlers.TimedRotatingFileHandler(result_dir / 'logs.txt', when='D', interval=7), logging.StreamHandler()])
390+
handlers = [logging.handlers.TimedRotatingFileHandler(output_dir / 'logs.txt', when='D', interval=7), logging.StreamHandler()])
386391

387392
for video_file, time_range, frame_range, frame_rate in zip(video_files, time_ranges, frame_ranges, frame_rates):
388393
currentDateAndTime = datetime.now()
@@ -405,7 +410,15 @@ def process(config='Config_demo.toml'):
405410
logging.info(f"On {currentDateAndTime.strftime('%A %d. %B %Y, %H:%M:%S')}")
406411
logging.info("---------------------------------------------------------------------")
407412

408-
process_fun(config_dict, video_file, frame_range, result_dir)
413+
if video_file != "webcam":
414+
video_file_path = video_dir / video_file
415+
416+
pose_tracker = setup_pose_tracker(det_frequency, mode)
417+
418+
logging.info(f'Pose tracking set up for BodyWithFeet model in {mode} mode.')
419+
logging.info(f'Persons are detected every {det_frequency} frames and tracked inbetween.')
420+
421+
process_fun(config_dict, video_file_path, pose_tracker, frame_range, output_dir)
409422

410423
elapsed_time = (datetime.now() - currentDateAndTime).total_seconds()
411424
logging.info(f'\nProcessing {video_file} took {elapsed_time:.2f} s.')

Sports2D/Utilities/common.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@
1717
## INIT
1818
import re
1919
import sys
20+
import cv2
2021
import subprocess
22+
import logging
2123
from pathlib import Path
2224

2325
import numpy as np
2426
from scipy import interpolate
2527
import imageio_ffmpeg as ffmpeg
28+
from rtmlib import PoseTracker, BodyWithFeet, Wholebody, Body
2629

2730
from PyQt5.QtWidgets import QMainWindow, QApplication, QWidget, QTabWidget, QVBoxLayout
2831
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
@@ -257,3 +260,172 @@ def euclidean_distance(q1, q2):
257260
euc_dist = np.sqrt(np.sum( [d**2 for d in dist]))
258261

259262
return euc_dist
263+
264+
def _select_backend_and_device():
    '''
    Pick the inference backend and device from what is installed.

    Preference order: CUDA, then ROCm (both require torch to report an available
    GPU and the matching ONNXRuntime execution provider), then MPS/CoreML
    (ONNXRuntime provider only), and finally OpenVINO on CPU.

    OUTPUTS:
    - backend: str. 'onnxruntime' or 'openvino'
    - device: str. 'cuda', 'rocm', 'mps', or 'cpu'
    '''

    # A missing or broken torch/onnxruntime install must not abort processing:
    # falling back to the CPU is deliberate. Catching Exception (instead of a bare
    # except) still lets KeyboardInterrupt and SystemExit propagate.
    try:
        import onnxruntime as ort
        providers = ort.get_available_providers()
    except Exception:
        providers = []
    try:
        import torch
        gpu_available = bool(torch.cuda.is_available())
    except Exception:
        gpu_available = False

    if gpu_available and 'CUDAExecutionProvider' in providers:
        logging.info("\nValid CUDA installation found: using ONNXRuntime backend with GPU.")
        return 'onnxruntime', 'cuda'
    if gpu_available and 'ROCMExecutionProvider' in providers:
        logging.info("\nValid ROCM installation found: using ONNXRuntime backend with GPU.")
        return 'onnxruntime', 'rocm'
    if 'MPSExecutionProvider' in providers or 'CoreMLExecutionProvider' in providers:
        logging.info("\nValid MPS installation found: using ONNXRuntime backend with GPU.")
        return 'onnxruntime', 'mps'

    logging.info("\nNo valid CUDA installation found: using OpenVINO backend with CPU.")
    return 'openvino', 'cpu'


def setup_pose_tracker(det_frequency, mode, pose_model = "HALPE_26"):
    '''
    Set up the RTMLib pose tracker with the appropriate model and backend.
    If CUDA (or ROCm, or MPS/CoreML) is available, use it with ONNXRuntime backend;
    else use CPU with openvino

    INPUTS:
    - det_frequency: int. The frequency of pose detection (every N frames). Must be >= 1
    - mode: str. The mode of the pose tracker ('lightweight', 'balanced', 'performance')
    - pose_model: str. 'HALPE_26' (body and feet, default), 'COCO_133' (body, feet,
      hands, and face), or 'COCO_17' (body). Case-insensitive

    OUTPUTS:
    - pose_tracker: PoseTracker. The initialized pose tracker object

    RAISES:
    - ValueError if det_frequency is lower than 1 or if pose_model is not supported
    '''

    # Validate the inputs first, so that a bad configuration fails before
    # torch/onnxruntime are probed
    if not det_frequency >= 1:
        raise ValueError(f"Invalid det_frequency: {det_frequency}. Must be an integer greater or equal to 1.")
    model_name = pose_model.upper()
    if model_name not in ('HALPE_26', 'COCO_133', 'COCO_17'):
        raise ValueError(f"Invalid model_type: {pose_model}. Must be 'HALPE_26', 'COCO_133', or 'COCO_17'. Use another network (MMPose, DeepLabCut, OpenPose, AlphaPose, BlazePose...) and convert the output files if you need another model. See documentation.")

    backend, device = _select_backend_and_device()

    if det_frequency > 1:
        logging.info(f'Inference run only every {det_frequency} frames. Inbetween, pose estimation tracks previously detected points.')
    else:
        logging.info('Inference run on every single frame.')

    # Select the RTMLib solution class matching the requested keypoint set
    if model_name == 'HALPE_26':
        ModelClass = BodyWithFeet
        logging.info("Using HALPE_26 model (body and feet) for pose estimation.")
    elif model_name == 'COCO_133':
        ModelClass = Wholebody
        logging.info("Using COCO_133 model (body, feet, hands, and face) for pose estimation.")
    else: # 'COCO_17', the only remaining valid name
        ModelClass = Body
        logging.info("Using COCO_17 model (body) for pose estimation.")
    logging.info(f'Mode: {mode}.\n')

    # Initialize the pose tracker with the selected model.
    # RTMLib's own tracking and OpenPose keypoint ordering are both switched off.
    pose_tracker = PoseTracker(
        ModelClass,
        det_frequency=det_frequency,
        mode=mode,
        backend=backend,
        device=device,
        tracking=False,
        to_openpose=False)

    return pose_tracker
338+
339+
def setup_webcam(webcam_id, save_video, vid_output_path, input_size):
    '''
    Set up webcam capture with OpenCV.

    INPUTS:
    - webcam_id: int. The ID of the webcam to capture from
    - save_video: bool. Whether to create a video writer for the output
    - vid_output_path: str or Path. Where to save the video output (only used if save_video)
    - input_size: tuple. The requested size of the input frame (width, height)

    OUTPUTS:
    - cap: cv2.VideoCapture. The webcam capture object
    - out_vid: cv2.VideoWriter. The video writer object (None if save_video is False)
    - cam_width: int. The actual width of the webcam frame
    - cam_height: int. The actual height of the webcam frame
    - fps: float. The frame rate reported by the webcam (30 if it reports 0)

    RAISES:
    - ValueError if the webcam cannot be opened
    '''

    #, cv2.CAP_DSHOW launches faster but only works for windows and esc key does not work
    cap = cv2.VideoCapture(webcam_id)
    if not cap.isOpened():
        raise ValueError(f"Error: Could not open webcam #{webcam_id}. Make sure that your webcam is available and has the right 'webcam_id' (check in your Config.toml file).")

    # Request the size, then read back what the webcam actually applied
    # (the closest one it supports)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, input_size[0])
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, input_size[1])
    cam_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    cam_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps == 0: fps = 30

    if cam_width != input_size[0] or cam_height != input_size[1]:
        logging.warning(f"Warning: Your webcam does not support {input_size[0]}x{input_size[1]} resolution. Resolution set to the closest supported one: {cam_width}x{cam_height}.")

    out_vid = None
    if save_video:
        # 'mp4v' is used because 'avc1' (h264) has better compression and quality but may fail on some systems.
        # fourcc MJPG produces very large files but is faster. If it is too slow, consider using it and then converting the video to h264
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # str(): OpenCV builds without PathLike support reject pathlib.Path filenames
        out_vid = cv2.VideoWriter(str(vid_output_path), fourcc, fps, (cam_width, cam_height))

    return cap, out_vid, cam_width, cam_height, fps
385+
386+
387+
def setup_video(video_file_path, save_video, vid_output_path):
    '''
    Set up video capture with OpenCV.

    INPUTS:
    - video_file_path: str or Path. The path to the video file (with extension)
    - save_video: bool. Whether to save the video output
    - vid_output_path: str or Path. The path to save the video output (only used if save_video)

    OUTPUTS:
    - cap: cv2.VideoCapture. The video capture object
    - out_vid: cv2.VideoWriter. The video writer object (None if save_video is False)
    - cam_width: int. The width of the video
    - cam_height: int. The height of the video
    - fps: float. The frame rate reported for the video (30 if it reports 0)

    RAISES:
    - ValueError if video_file_path has no file extension
    - NameError if the file cannot be opened as a video
    '''

    video_file_path = Path(video_file_path)
    # A name without extension is most likely a misconfigured video_input
    if video_file_path.name == video_file_path.stem:
        raise ValueError("Please set video_input to 'webcam' or to a video file (with extension) in Config.toml")

    not_a_video_msg = f"{video_file_path} is not a video. Check video_dir and video_input in your Config.toml file."
    try:
        # str(): OpenCV builds without PathLike support reject pathlib.Path filenames
        cap = cv2.VideoCapture(str(video_file_path))
    except Exception as err:
        raise NameError(not_a_video_msg) from err
    if not cap.isOpened():
        cap.release()
        raise NameError(not_a_video_msg)

    cam_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    cam_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps == 0: fps = 30

    out_vid = None
    if save_video:
        # 'mp4v' is used because 'avc1' (h264) has better compression and quality but may fail on some systems
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out_vid = cv2.VideoWriter(str(vid_output_path), fourcc, fps, (cam_width, cam_height))

    return cap, out_vid, cam_width, cam_height, fps

0 commit comments

Comments
 (0)