From 9df5014022a59e57ff72169c7e14f4091b517d98 Mon Sep 17 00:00:00 2001 From: DI3D <63136834+DI3D@users.noreply.github.com> Date: Fri, 5 Aug 2022 18:59:20 -0400 Subject: [PATCH] Adding reduction code (#40) * Create test * Add example code Add agent.py example and SB3_save.py example. These files implement a method to remove the optimizer from the file to reduce file size. The first is an example of how to implement the load function on the reduced file in RLBot (which I presume is the main use). The second shows how to save in reduced form. * Delete test --- .../sb3_utils/sb3_file_reducer/SB3_save.py | 86 +++++++++++++++++++ .../sb3_utils/sb3_file_reducer/agent.py | 34 ++++++++ 2 files changed, 120 insertions(+) create mode 100644 rlgym_tools/sb3_utils/sb3_file_reducer/SB3_save.py create mode 100644 rlgym_tools/sb3_utils/sb3_file_reducer/agent.py diff --git a/rlgym_tools/sb3_utils/sb3_file_reducer/SB3_save.py b/rlgym_tools/sb3_utils/sb3_file_reducer/SB3_save.py new file mode 100644 index 0000000..b68ab4b --- /dev/null +++ b/rlgym_tools/sb3_utils/sb3_file_reducer/SB3_save.py @@ -0,0 +1,86 @@ +from stable_baselines3 import PPO +from typing import Iterable, Optional, Union, Tuple, List +import io +import pathlib + +from stable_baselines3.common.save_util import recursive_getattr, save_to_zip_file + + +# Example implementation of load hack +class TestOverrideLoad(PPO): + def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: + state_dicts = ["policy"] + + return state_dicts, [] + + +# Example implementation of save hack +class TestOverride(PPO): + def save( + self, + path: Union[str, pathlib.Path, io.BufferedIOBase], + exclude: Optional[Iterable[str]] = None, + include: Optional[Iterable[str]] = None, + ) -> None: + """ + Save all the attributes of the object and the model parameters in a zip-file. + + :param path: path to the file where the rl agent should be saved + :param exclude: name of parameters that should be excluded in addition to the default ones + :param include: name of parameters that might be excluded but should be included anyway + """ + # Copy parameter list so we don't mutate the original dict + data = self.__dict__.copy() + + # Copy so that we do not overwrite original exclude + exclude_original = exclude + + # Exclude is union of specified parameters (if any) and standard exclusions + if exclude is None: + exclude = [] + exclude = set(exclude).union(self._excluded_save_params()) + + # Do not exclude params if they are specifically included + if include is not None: + exclude = exclude.difference(include) + + state_dicts_names, torch_variable_names = self._get_torch_save_params() + all_pytorch_variables = state_dicts_names + torch_variable_names + for torch_var in all_pytorch_variables: + # We need to get only the name of the top most module as we'll remove that + var_name = torch_var.split(".")[0] + # Any params that are in the save vars must not be saved by data + exclude.add(var_name) + + # Remove parameter entries of parameters which are to be excluded + for param_name in exclude: + data.pop(param_name, None) + + # Build dict of torch variables + pytorch_variables = None + if torch_variable_names is not None: + pytorch_variables = {} + for name in torch_variable_names: + attr = recursive_getattr(self, name) + pytorch_variables[name] = attr + + # Build dict of state_dicts + params_to_save = self.get_parameters() + # So we don't get dict change errors + params_to_save_2 = params_to_save.copy() + + if params_to_save is not None: + for file_name, dict_ in params_to_save.items(): + for param_name in exclude_original: + if param_name == file_name: + params_to_save_2.pop(file_name) + + save_to_zip_file(path, data=data, params=params_to_save_2, pytorch_variables=pytorch_variables) + + +if __name__ == "__main__": + test_class = TestOverride.load("exit_save") + # Not sure if "optimizer" is needed + test_class.save("reduced_save", exclude=["policy.optimizer", "optimizer"]) + # Only for checking to make sure the save and load work, not needed + model = TestOverrideLoad.load("reduced_save") diff --git a/rlgym_tools/sb3_utils/sb3_file_reducer/agent.py b/rlgym_tools/sb3_utils/sb3_file_reducer/agent.py new file mode 100644 index 0000000..a467a51 --- /dev/null +++ b/rlgym_tools/sb3_utils/sb3_file_reducer/agent.py @@ -0,0 +1,34 @@ +# +# ONLY AN EXAMPLE +# + +from stable_baselines3 import PPO +import pathlib +from discrete_act import DiscreteAction + + +# This is so there are no optimizer load errors, hacky, but it works +class TestOverrideLoad(PPO): + def _get_torch_save_params(self): + state_dicts = ["policy"] + + return state_dicts, [] + + +class Agent: + def __init__(self): + _path = pathlib.Path(__file__).parent.resolve() + custom_objects = { + "lr_schedule": 0.000001, + "clip_range": .02, + "n_envs": 1, + } + + self.actor = TestOverrideLoad.load(str(_path) + '/example_mdl', device='cpu', custom_objects=custom_objects) + self.parser = DiscreteAction() + + def act(self, state): + action = self.actor.predict(state, deterministic=True) + x = self.parser.parse_actions(action[0], state) + + return x[0]