ddpg huggingface integration (vwxyzjn#407)
* add huggingface integration to ddpg

* ddpg_jax huggingface integration

* update docs

* update tests

* fix type-hint in py38

* fix type-hint in py38

* try fix

* retry

* reretry

* rereretry

* deprecated macos mujoco env tests
sdpkjc authored Jul 4, 2023
1 parent 6c3e3c5 commit f62ad1f
Showing 8 changed files with 253 additions and 28 deletions.
28 changes: 0 additions & 28 deletions .github/workflows/tests.yaml
@@ -204,34 +204,6 @@ jobs:
- name: Run mujoco tests
run: poetry run pytest tests/test_mujoco.py

test-mujoco-gymnasium-mac:
strategy:
fail-fast: false
matrix:
python-version: [3.8]
poetry-version: [1.3.1]
os: [macos-latest]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Run image
uses: abatilo/actions-poetry@v2.0.0
with:
poetry-version: ${{ matrix.poetry-version }}

# mujoco tests
- name: Install dependencies
run: poetry install -E "pytest mujoco dm_control jax"
- name: Downgrade setuptools
run: poetry run pip install setuptools==59.5.0
- name: Run gymnasium migration dependencies
run: poetry run pip install "stable_baselines3==2.0.0a1"
- name: Run mujoco tests
run: poetry run pytest tests/test_mujoco_gymnasium.py

test-mujoco_py-envs:
strategy:
fail-fast: false
32 changes: 32 additions & 0 deletions cleanrl/ddpg_continuous_action.py
@@ -34,6 +34,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
@@ -247,5 +253,31 @@ def forward(self, x):
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
torch.save((actor.state_dict(), qf1.state_dict()), model_path)
print(f"model saved to {model_path}")
from cleanrl_utils.evals.ddpg_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=(Actor, QNetwork),
device=device,
exploration_noise=args.exploration_noise,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "DDPG", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
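
A minimal usage sketch (not part of the diff) showing how the new flags could be exercised end to end, in the same `subprocess` style as the tests below; the Hugging Face entity is a placeholder, and `--upload-model` assumes an authenticated Hugging Face token (e.g. via `huggingface-cli login`).

```python
# Sketch only: short training run that saves a checkpoint, evaluates it, and pushes it to the Hub.
# "your-hf-username" is a placeholder entity; uploading assumes a configured Hugging Face token.
import subprocess

subprocess.run(
    "python cleanrl/ddpg_continuous_action.py --env-id HalfCheetah-v4"
    " --total-timesteps 1000 --learning-starts 100 --batch-size 32"
    " --save-model True --upload-model True --hf-entity your-hf-username",
    shell=True,
    check=True,
)
```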
39 changes: 39 additions & 0 deletions cleanrl/ddpg_continuous_action_jax.py
@@ -33,6 +33,12 @@ def parse_args():
help="the entity (team) of wandb's project")
parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to capture videos of the agent performances (check out `videos` folder)")
parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to save model into the `runs/{run_name}` folder")
parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
help="whether to upload the saved model to huggingface")
parser.add_argument("--hf-entity", type=str, default="",
help="the user or org name of the model repository from the Hugging Face Hub")

# Algorithm specific arguments
parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
@@ -292,5 +298,38 @@ def actor_loss(params):
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(
[
actor_state.params,
qf1_state.params,
]
)
)
print(f"model saved to {model_path}")
from cleanrl_utils.evals.ddpg_jax_eval import evaluate

episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=(Actor, QNetwork),
exploration_noise=args.exploration_noise,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)

if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub

repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "DDPG", f"runs/{run_name}", f"videos/{run_name}-eval")

envs.close()
writer.close()
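
A side note, not part of the diff: the JAX variant stores the parameter pytrees with `flax.serialization.to_bytes`, and restoring with `from_bytes` requires a template pytree of matching structure, which is why the eval script further down initializes fresh actor/critic params before loading. A self-contained round-trip sketch with a toy pytree (names and shapes are illustrative only):

```python
# Sketch only: round-trip a toy [actor, qf] parameter list the way the script saves it.
import flax.serialization
import numpy as np

params = {"w": np.ones((2, 3)), "b": np.zeros(3)}
raw = flax.serialization.to_bytes([params, params])  # mirrors the saved [actor_state.params, qf1_state.params]

# from_bytes restores into a template with the same structure, e.g. freshly initialized params.
template = [{"w": np.zeros((2, 3)), "b": np.zeros(3)} for _ in range(2)]
actor_like, qf_like = flax.serialization.from_bytes(template, raw)
assert np.allclose(actor_like["w"], 1.0)
```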
66 changes: 66 additions & 0 deletions cleanrl_utils/evals/ddpg_eval.py
@@ -0,0 +1,66 @@
from typing import Callable

import gymnasium as gym
import torch
import torch.nn as nn


def evaluate(
model_path: str,
make_env: Callable,
env_id: str,
eval_episodes: int,
run_name: str,
Model: nn.Module,
device: torch.device = torch.device("cpu"),
capture_video: bool = True,
exploration_noise: float = 0.1,
):
envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)])
actor = Model[0](envs).to(device)
qf = Model[1](envs).to(device)
actor_params, qf_params = torch.load(model_path, map_location=device)
actor.load_state_dict(actor_params)
actor.eval()
qf.load_state_dict(qf_params)
qf.eval()
# note: qf is not used in this script

obs, _ = envs.reset()
episodic_returns = []
while len(episodic_returns) < eval_episodes:
with torch.no_grad():
actions = actor(torch.Tensor(obs).to(device))
actions += torch.normal(0, actor.action_scale * exploration_noise)
actions = actions.cpu().numpy().clip(envs.single_action_space.low, envs.single_action_space.high)

next_obs, _, _, _, infos = envs.step(actions)
if "final_info" in infos:
for info in infos["final_info"]:
if "episode" not in info:
continue
print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
episodic_returns += [info["episode"]["r"]]
obs = next_obs

return episodic_returns


if __name__ == "__main__":
from huggingface_hub import hf_hub_download

from cleanrl.ddpg_continuous_action import Actor, QNetwork, make_env

model_path = hf_hub_download(
repo_id="cleanrl/HalfCheetah-v4-ddpg_continuous_action-seed1", filename="ddpg_continuous_action.cleanrl_model"
)
evaluate(
model_path,
make_env,
"HalfCheetah-v4",
eval_episodes=10,
run_name=f"eval",
Model=(Actor, QNetwork),
device="cpu",
capture_video=False,
)
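
A usage sketch, not part of the diff: the `__main__` block above pulls a checkpoint from the Hub, but `evaluate` accepts any path written by `--save-model`; the run directory below is a placeholder.

```python
# Sketch only: evaluate a locally saved checkpoint instead of a Hub download.
from cleanrl.ddpg_continuous_action import Actor, QNetwork, make_env
from cleanrl_utils.evals.ddpg_eval import evaluate

episodic_returns = evaluate(
    "runs/<your-run-name>/ddpg_continuous_action.cleanrl_model",  # placeholder path written by --save-model
    make_env,
    "HalfCheetah-v4",
    eval_episodes=5,
    run_name="local-eval",
    Model=(Actor, QNetwork),
    device="cpu",
    capture_video=False,
)
print(episodic_returns)
```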
82 changes: 82 additions & 0 deletions cleanrl_utils/evals/ddpg_jax_eval.py
@@ -0,0 +1,82 @@
from typing import Callable

import flax
import flax.linen as nn
import gymnasium as gym
import jax
import numpy as np


def evaluate(
model_path: str,
make_env: Callable,
env_id: str,
eval_episodes: int,
run_name: str,
Model: nn.Module,
capture_video: bool = True,
exploration_noise: float = 0.1,
seed=1,
):
envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)])
obs, _ = envs.reset()

Actor, QNetwork = Model
action_scale = np.array((envs.action_space.high - envs.action_space.low) / 2.0)
action_bias = np.array((envs.action_space.high + envs.action_space.low) / 2.0)
actor = Actor(
action_dim=np.prod(envs.single_action_space.shape),
action_scale=action_scale,
action_bias=action_bias,
)
qf = QNetwork()
key = jax.random.PRNGKey(seed)
key, actor_key, qf_key = jax.random.split(key, 3)
actor_params = actor.init(actor_key, obs)
qf_params = qf.init(qf_key, obs, envs.action_space.sample())
# note: qf_params is not used in this script
with open(model_path, "rb") as f:
(actor_params, qf_params) = flax.serialization.from_bytes((actor_params, qf_params), f.read())
actor.apply = jax.jit(actor.apply)
qf.apply = jax.jit(qf.apply)

episodic_returns = []
while len(episodic_returns) < eval_episodes:
actions = actor.apply(actor_params, obs)
actions = np.array(
[
(jax.device_get(actions)[0] + np.random.normal(0, action_scale * exploration_noise)[0]).clip(
envs.single_action_space.low, envs.single_action_space.high
)
]
)

next_obs, _, _, _, infos = envs.step(actions)
if "final_info" in infos:
for info in infos["final_info"]:
if "episode" not in info:
continue
print(f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}")
episodic_returns += [info["episode"]["r"]]
obs = next_obs

return episodic_returns


if __name__ == "__main__":
from huggingface_hub import hf_hub_download

from cleanrl.ddpg_continuous_action_jax import Actor, QNetwork, make_env

model_path = hf_hub_download(
repo_id="cleanrl/HalfCheetah-v4-ddpg_continuous_action_jax-seed1", filename="ddpg_continuous_action_jax.cleanrl_model"
)
evaluate(
model_path,
make_env,
"HalfCheetah-v4",
eval_episodes=10,
run_name=f"eval",
Model=(Actor, QNetwork),
exploration_noise=0.1,
)
2 changes: 2 additions & 0 deletions docs/get-started/zoo.md
@@ -16,6 +16,8 @@ CleanRL now has 🧪 experimental support for saving and loading models from
| | :material-github: [`c51_atari.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari.py), :material-file-document: [docs](/rl-algorithms/c51/#c51_ataripy) |
| | :material-github: [`c51_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_jax.py), :material-file-document: [docs](/rl-algorithms/c51/#c51_jaxpy) |
| | :material-github: [`c51_atari_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/c51_atari_jax.py), :material-file-document: [docs](/rl-algorithms/c51/#c51_atari_jaxpy) |
|[Deep Deterministic Policy Gradient (DDPG)](https://arxiv.org/pdf/1509.02971.pdf) | :material-github: [`ddpg_continuous_action.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action.py), :material-file-document: [docs](/rl-algorithms/ddpg/#ddpg_continuous_actionpy) |
| | :material-github: [`ddpg_continuous_action_jax.py`](https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ddpg_continuous_action_jax.py), :material-file-document: [docs](/rl-algorithms/ddpg/#ddpg_continuous_action_jaxpy) |


## Load models from the Model Hub
16 changes: 16 additions & 0 deletions tests/test_mujoco_gymnasium.py
@@ -15,3 +15,19 @@ def test_mujoco():
shell=True,
check=True,
)


def test_mujoco_eval():
"""
Test mujoco_eval
"""
subprocess.run(
"python cleanrl/ddpg_continuous_action.py --save-model True --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105",
shell=True,
check=True,
)
subprocess.run(
"python cleanrl/ddpg_continuous_action_jax.py --save-model True --env-id Hopper-v4 --learning-starts 100 --batch-size 32 --total-timesteps 105",
shell=True,
check=True,
)
16 changes: 16 additions & 0 deletions tests/test_mujoco_py_gymnasium.py
@@ -15,3 +15,19 @@ def test_mujoco_py():
shell=True,
check=True,
)


def test_mujoco_py_eval():
"""
Test mujoco_py_eval
"""
subprocess.run(
"python cleanrl/ddpg_continuous_action.py --save-model True --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105",
shell=True,
check=True,
)
subprocess.run(
"python cleanrl/ddpg_continuous_action_jax.py --save-model True --env-id Hopper-v2 --learning-starts 100 --batch-size 32 --total-timesteps 105",
shell=True,
check=True,
)
