add video saving and uploading support to train_* scripts
#524
Conversation
AdamGleave left a comment
Took a quick look, only skimmed since this is still in draft mode. Seems like a useful feature; a couple of suggestions.
Codecov Report
@@            Coverage Diff             @@
##           master     #524      +/-   ##
==========================================
- Coverage   96.95%   96.93%    -0.03%
==========================================
  Files          84       84
  Lines        7460     7369       -91
==========================================
- Hits         7233     7143       -90
+ Misses        227      226        -1
    )
    callback_objs.append(save_policy_callback)

    if _config["train"]["videos"]:
Here we need to init a video_wrapper.SaveVideoCallback instead of using train.save_video like other scripts do. A bit unsatisfying.
An alternative could be passing a save_video partial function into the callback.
Yes, this is strange; why is that the case? I would advocate for using the callback class everywhere, or using a partial / closure+wrapper defined in this file for this specific instance. Currently the existence of the class is confusing and not documented.
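A rough sketch of the partial / closure alternative being floated in this thread; the helper names (train.save_video, video_wrapper.SaveVideoCallback) come from the discussion above, but the exact signatures are assumptions for illustration only:

    import functools

    # Bind the script-level arguments up front; `rl_algo`, `log_dir` and
    # `callback_objs` refer to objects already defined in this script.
    save_video_fn = functools.partial(
        train.save_video,  # hypothetical helper from the train ingredient
        policy=rl_algo.policy,
        log_dir=log_dir,
    )

    # Hand the plain callable to the callback wrapper instead of having the
    # wrapper re-implement the saving logic itself.
    callback_objs.append(video_wrapper.SaveVideoCallback(save_video_fn))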
    rl_algo.set_logger(custom_logger)
    rl_algo.learn(total_timesteps, callback=callback)

    with common.make_venv(num_vec=1, log_dir=None) as eval_venv:
Create an eval_venv:
- with num_vec=1,
- without creating monitors from here, by setting log_dir=None.
I'm still a bit backlogged. @Rocamonde, could you review this please?
    total_timesteps = int(1e6)  # total number of environment timesteps
    total_comparisons = 5000  # total number of comparisons to elicit
-   num_iterations = 5  # Arbitrary, should be tuned for the task
+   num_iterations = 50  # Arbitrary, should be tuned for the task
Apologies if this has been discussed, but why are you doing this?
    cross_entropy_loss_kwargs = {}
    reward_trainer_kwargs = {
        "epochs": 3,
        "weight_decay": 0.0,
I'll have to remember to change this, as I have a PR that replaces weight decay with a general regularization API (#481). @AdamGleave what do you think, should we merge my PR or this one first?
Probably best to merge your PR first, though it really depends on which one is ready earlier.
#481 is ready and passing all the tests AFAIK.
    total_timesteps: int,
    total_comparisons: int,
-   callback: Optional[Callable[[int], None]] = None,
+   callback: Optional[Callable[[int, int], None]] = None,
Probably should add in the docstring what the callback type signature represents.
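For illustration, something along these lines; the function name and what the two integers actually represent are assumptions here and should be taken from the implementation:

    from typing import Callable, Optional

    def train_preference_comparisons(
        total_timesteps: int,
        total_comparisons: int,
        callback: Optional[Callable[[int, int], None]] = None,
    ):
        """Train a reward model from preference comparisons.

        Args:
            total_timesteps: total number of environment timesteps.
            total_comparisons: total number of comparisons to elicit.
            callback: if provided, called once per iteration with two integers
                (e.g. the iteration index and the timesteps elapsed so far;
                whatever the code actually passes should be documented here).
        """
        ...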
    @train_ingredient.capture
    def save_video(
When you call this function, it self-documents as if the video were always saved (but a flag indicating whether this should happen is magically injected through a decorator). I don't have an immediately better alternative, but perhaps a more explanatory function name could help.
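To make the concern concrete, a minimal sketch of the pattern, assuming a Sacred-captured `videos` flag as in the diff above; `maybe_save_video` is just one possible more explanatory name, not the PR's actual choice:

    from sacred import Ingredient

    train_ingredient = Ingredient("train")

    @train_ingredient.capture
    def maybe_save_video(policy, eval_venv, log_dir: str, videos: bool = False) -> None:
        """Record an evaluation video of `policy`, but only if `videos` is enabled."""
        # `videos` is filled in by Sacred from the ingredient config, so call
        # sites read `maybe_save_video(policy, eval_venv, log_dir)`; the name
        # at least signals that saving is conditional.
        if not videos:
            return
        # ... wrap eval_venv in a VideoWrapper and roll the policy out ...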
    round_str: str,
) -> None:
    """Save discriminator and generator."""
    save_path = os.path.join(log_dir, "checkpoints", round_str)
I have a PR for replacing os.path with pathlib in most places, but might as well keep it consistent for now until that's merged.
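For reference, the pathlib equivalent of the line above would look roughly like this (placeholder values stand in for the real arguments):

    import os
    from pathlib import Path

    log_dir, round_str = "output/train", "000"  # placeholders for the real arguments

    # current style in this PR:
    save_path = os.path.join(log_dir, "checkpoints", round_str)
    # pathlib equivalent once the other PR lands:
    save_path = Path(log_dir) / "checkpoints" / round_str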
| """ | ||
| super().__init__(env) | ||
| self.episode_id = 0 | ||
| self._episode_id = 0 |
Why make it private?
        directory=video_dir,
        **(video_kwargs or dict()),
    )
    sample_until = rollout.make_sample_until(min_timesteps=None, min_episodes=1)
I understand where the name of this function is coming from ("make the function called sample_until"), but how it actually reads IMO is "make the sample (until...?)". I think that refactoring this to something like "get_stopping_conditions_callback" or "get_sampling_termination_fn" would be much more readable.
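For context, my understanding (from memory of the imitation API, so treat the details as an assumption) is that make_sample_until builds a termination predicate over the trajectories collected so far rather than doing any sampling itself, which is what makes the current name read oddly:

    from imitation.data import rollout

    # The returned callable is checked by generate_trajectories after each
    # batch of episodes; it answers "have we sampled enough yet?".
    sample_until = rollout.make_sample_until(min_timesteps=None, min_episodes=1)
    sample_until([])  # False: no complete episode collected yet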
    sample_until = rollout.make_sample_until(min_timesteps=None, min_episodes=1)
    # "video.{:06}.mp4".format(VideoWrapper.episode_id) will be saved within
    # rollout.generate_trajectories()
    rollout.generate_trajectories(policy, video_venv, sample_until)
For some reason I was expecting that the saved video would be one of the real training trajectories instead of a newly sampled one.
Closing in favor of #597
Description

Closes #523.

Problem

- Video saving is not supported by the training scripts: scripts.train_rl, scripts.train_preference_comparisons, scripts.train_adversarial and scripts.train_bc.

Solution

- A record_and_save_video() function in imitation.util.video_wrapper that takes in a policy, eval_venv, and a logger to save the video of a policy evaluated on an environment to a designated path.
- Video uploading support in WandbOutputFormat.write() by adding the following:
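A hypothetical sketch of how the new helper might be called from a train_* script; the keyword names are guesses from the description above, and common.make_venv, rl_algo and custom_logger stand for objects from the surrounding script:

    from imitation.util import video_wrapper

    # Roll the trained policy out on a fresh single-env VecEnv and write the
    # video (and, with the W&B logger configured, upload it) under the run's
    # output directory.
    with common.make_venv(num_vec=1, log_dir=None) as eval_venv:
        video_wrapper.record_and_save_video(
            policy=rl_algo.policy,
            eval_venv=eval_venv,
            logger=custom_logger,
        )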
Testing

- tests/scripts/test_scripts.py
- tests/util/test_wb_logger.py