Skip to content

Commit

Permalink
request video clips
Browse files Browse the repository at this point in the history
  • Loading branch information
Haichao-Zhang committed Feb 18, 2025
1 parent e140cd2 commit 0590c63
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions alf/trainers/policy_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import abc
from absl import logging
import cv2
from functools import partial
from typing import Dict
import math
Expand Down Expand Up @@ -317,6 +318,14 @@ def train(self):
"Use `kill -%s %s` to request checkpoint during training." %
(int(signal.SIGUSR2), self._pid))

self._video_clip_requested = False
if threading.current_thread() == threading.main_thread():
signal.signal(signal.SIGRTMIN, self._request_video_clip)
# kill -34 PID
logging.info(
"Use `kill -%s %s` to request video-clip during training." %
(int(signal.SIGRTMIN), self._pid))

if (threading.current_thread() == threading.main_thread()
and PerProcessContext().ddp_rank <= 0):
# Debugging in subprocesses is not supported because they don't have
Expand Down Expand Up @@ -472,6 +481,9 @@ def _markdownify(paragraph):
def _request_checkpoint(self, signum, frame):
self._checkpoint_requested = True

def _request_video_clip(self, signum, frame):
self._video_clip_requested = True

def _request_debug(self, signum, frame):
breakpoint()

Expand All @@ -482,6 +494,42 @@ def _save_checkpoint(self):
global_step = alf.summary.get_global_counter()
self._checkpointer.save(global_step=global_step)

def _save_video_clip(self, name: str = "video_clip"):
# Saving video clip is only enabled when running single process training
# (rank is -1) or master process of DDP training (rank is 0).
if self._rank <= 0:
env = alf.get_env()
# to use this function, the env.render() function should support ``online`` mode
frames = env.render(mode="online")
common.warning_once("Caching video frames...")
if len(frames) > 0:
video_dir = os.path.join(self._train_dir, 'video')
os.makedirs(video_dir, exist_ok=True)
# save video
logging.info("Saving video clip...")

height, width, _ = frames[0].shape
fps = 30
global_step = alf.summary.get_global_counter()
output_file = os.path.join(video_dir,
f"{name}_{global_step}.mp4")

# define the video writer (codec: MP4V for .mp4 files)
fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(output_file, fourcc, fps,
(width, height))

# write frames to the video file
for frame in frames:
out.write(cv2.cvtColor(
frame,
cv2.COLOR_RGB2BGR)) # Convert RGB to BGR for OpenCV

# release the writer
out.release()

self._video_clip_requested = False

def _restore_checkpoint(self, checkpointer):
"""Restore from saved checkpoint.
Expand Down Expand Up @@ -711,6 +759,9 @@ def _train(self):
logging.info("Saving checkpoint upon request...")
self._save_checkpoint()
self._checkpoint_requested = False
elif self._video_clip_requested:
common.warning_once("Saving video-clip upon request...")
self._save_video_clip()

if self._evaluate:
self._evaluator.wait_complete()
Expand Down

0 comments on commit 0590c63

Please sign in to comment.