-
Notifications
You must be signed in to change notification settings - Fork 1
/
ppo.py
208 lines (191 loc) · 8.79 KB
/
ppo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
from dataclasses import dataclass
from typing import Any, Generic, Literal, TypeVar
import gymnasium as gym
import numpy as np
import torch
from torch import nn
from data import ReplayBuffer, SequenceSummaryStats, to_torch_as
from data.types import LogpOldProtocol, RolloutBatchProtocol
from policy_set import A2CPolicy
from policy_set.base import TLearningRateScheduler, TrainingStats
from policy_set.modelfree import TDistributionFunction
from utils.net.common import ActorCritic
@dataclass(kw_only=True)
class PPOTrainingStats(TrainingStats):
    """Per-update loss statistics reported by a PPO learning step.

    Each field is a summary over the values collected once per minibatch
    gradient step during a single call to ``learn``.
    """

    # Summary of the combined objective that is back-propagated.
    loss: SequenceSummaryStats
    # Summary of the clipped-surrogate (policy) loss term.
    clip_loss: SequenceSummaryStats
    # Summary of the value-function (critic) loss term.
    vf_loss: SequenceSummaryStats
    # Summary of the mean policy entropy term.
    ent_loss: SequenceSummaryStats
# Type variable for the statistics type returned by `PPOPolicy.learn`;
# bound to PPOTrainingStats so that it (or any subclass of it) can be used.
TPPOTrainingStats = TypeVar("TPPOTrainingStats", bound=PPOTrainingStats)
# TODO: the type ignore here is needed b/c the hierarchy is actually broken! Should reconsider the inheritance structure.
class PPOPolicy(A2CPolicy[TPPOTrainingStats], Generic[TPPOTrainingStats]):  # type: ignore[type-var]
    r"""Proximal Policy Optimization policy (arXiv:1707.06347).

    :param actor: the actor network following the rules in BasePolicy. (s -> logits)
    :param critic: the critic network. (s -> V(s))
    :param optim: the optimizer shared by the actor and critic networks.
    :param dist_fn: distribution class used to compute the action.
    :param action_space: the environment's action space.
    :param eps_clip: :math:`\epsilon` in :math:`L_{CLIP}` of the original paper.
        It is also used as the clipping range of the value estimate when
        ``value_clip`` is enabled.
    :param dual_clip: the constant c from arXiv:1912.09729 Equ. 5, where c > 1
        indicates the lower bound. ``None`` disables dual-clip PPO.
    :param value_clip: whether to clip the value estimate as mentioned in
        arXiv:1811.02553v3 Sec. 4.1.
    :param advantage_normalization: whether to normalize the advantage within
        each minibatch.
    :param recompute_advantage: whether to recompute the advantage on every
        update repeat, see https://arxiv.org/pdf/2006.05990.pdf Sec. 3.5.
    :param vf_coef: weight of the value loss.
    :param ent_coef: weight of the entropy term.
    :param max_grad_norm: gradient-norm clipping threshold used in back propagation.
    :param gae_lambda: in [0, 1], parameter for Generalized Advantage Estimation.
    :param max_batchsize: the maximum size of the batch when computing GAE.
    :param discount_factor: in [0, 1].
    :param reward_normalization: normalize estimated values to have std close to 1.
    :param deterministic_eval: if True, use deterministic evaluation.
    :param observation_space: the space of the observation.
    :param action_scaling: if True, scale the action from [-1, 1] to the range
        of action_space. Only used if the action_space is continuous.
    :param action_bound_method: method to bound the action to range [-1, 1].
    :param lr_scheduler: if not None, will be called in `policy.update()`.

    .. seealso::

        Please refer to :class:`~tianshou.policy.BasePolicy` for more detailed
        explanation.
    """

    def __init__(
        self,
        *,
        actor: torch.nn.Module,
        critic: torch.nn.Module,
        optim: torch.optim.Optimizer,
        dist_fn: TDistributionFunction,
        action_space: gym.Space,
        eps_clip: float = 0.2,
        dual_clip: float | None = None,
        value_clip: bool = False,
        advantage_normalization: bool = True,
        recompute_advantage: bool = False,
        vf_coef: float = 0.5,
        ent_coef: float = 0.01,
        max_grad_norm: float | None = None,
        gae_lambda: float = 0.95,
        max_batchsize: int = 256,
        discount_factor: float = 0.99,
        # TODO: rename to return_normalization?
        reward_normalization: bool = False,
        deterministic_eval: bool = False,
        observation_space: gym.Space | None = None,
        action_scaling: bool = True,
        action_bound_method: Literal["clip", "tanh"] | None = "clip",
        lr_scheduler: TLearningRateScheduler | None = None,
    ) -> None:
        """Validate ``dual_clip``, initialise the A2C base and store the PPO options."""
        assert dual_clip is None or dual_clip > 1.0, (
            f"Dual-clip PPO parameter should greater than 1.0 but got {dual_clip}"
        )
        # Everything that is not PPO-specific is handled by the A2C base class.
        super().__init__(
            actor=actor,
            critic=critic,
            optim=optim,
            dist_fn=dist_fn,
            action_space=action_space,
            vf_coef=vf_coef,
            ent_coef=ent_coef,
            max_grad_norm=max_grad_norm,
            gae_lambda=gae_lambda,
            max_batchsize=max_batchsize,
            discount_factor=discount_factor,
            reward_normalization=reward_normalization,
            deterministic_eval=deterministic_eval,
            observation_space=observation_space,
            action_scaling=action_scaling,
            action_bound_method=action_bound_method,
            lr_scheduler=lr_scheduler,
        )
        # PPO-specific options.
        self.eps_clip = eps_clip
        self.dual_clip = dual_clip
        self.value_clip = value_clip
        self.norm_adv = advantage_normalization
        self.recompute_adv = recompute_advantage
        # Declared for type checkers only; this method does not assign it.
        self._actor_critic: ActorCritic

    def process_fn(
        self,
        batch: RolloutBatchProtocol,
        buffer: ReplayBuffer,
        indices: np.ndarray,
    ) -> LogpOldProtocol:
        """Prepare a collected batch for :meth:`learn`.

        Computes the returns via ``_compute_returns``, converts ``batch.act``
        with ``to_torch_as`` using ``batch.v_s`` as the reference, and stores
        the log-probability of the taken actions under the current policy in
        ``batch.logp_old`` (evaluated without gradient tracking).

        :param batch: the sampled rollout batch.
        :param buffer: the replay buffer the batch was sampled from.
        :param indices: the buffer indices of the sampled transitions.
        :return: the same batch, extended with ``logp_old``.
        """
        if self.recompute_adv:
            # Keep `buffer` and `indices` around so that `learn()` can
            # recompute the returns on later repeats.
            self._buffer = buffer
            self._indices = indices
        batch = self._compute_returns(batch, buffer, indices)
        batch.act = to_torch_as(batch.act, batch.v_s)
        with torch.no_grad():
            action_dist = self(batch).dist
            batch.logp_old = action_dist.log_prob(batch.act)
        batch: LogpOldProtocol
        return batch

    # TODO: why does mypy complain?
    def learn(  # type: ignore
        self,
        batch: RolloutBatchProtocol,
        batch_size: int | None,
        repeat: int,
        *args: Any,
        **kwargs: Any,
    ) -> TPPOTrainingStats:
        """Run ``repeat`` passes of minibatch PPO updates over ``batch``.

        :param batch: the batch produced by :meth:`process_fn`.
        :param batch_size: the minibatch size; a falsy value (``None`` or 0)
            is passed to ``batch.split`` as ``-1``.
        :param repeat: how many times to iterate over the whole batch.
        :return: summary statistics of the total, clip, value and entropy
            terms collected over all minibatch updates.
        """
        total_hist: list[float] = []
        clip_hist: list[float] = []
        vf_hist: list[float] = []
        ent_hist: list[float] = []

        minibatch_size = batch_size if batch_size else -1
        for epoch in range(repeat):
            # The first pass uses the returns already computed in `process_fn`.
            if epoch > 0 and self.recompute_adv:
                batch = self._compute_returns(batch, self._buffer, self._indices)

            for mb in batch.split(minibatch_size, merge_last=True):
                # ---- actor: clipped surrogate objective ----
                dist = self(mb).dist
                if self.norm_adv:
                    # per-minibatch advantage normalization
                    adv_mean = mb.adv.mean()
                    adv_std = mb.adv.std()
                    mb.adv = (mb.adv - adv_mean) / (adv_std + self._eps)

                log_ratio = dist.log_prob(mb.act) - mb.logp_old
                ratio = log_ratio.exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)

                unclipped_surr = ratio * mb.adv
                low, high = 1.0 - self.eps_clip, 1.0 + self.eps_clip
                clipped_surr = ratio.clamp(low, high) * mb.adv
                objective = torch.min(unclipped_surr, clipped_surr)
                if self.dual_clip:
                    # Dual-clip: where the advantage is negative, bound the
                    # objective from below by `dual_clip * adv`.
                    bounded = torch.max(objective, self.dual_clip * mb.adv)
                    objective = torch.where(mb.adv < 0, bounded, objective)
                clip_loss = -objective.mean()

                # ---- critic: (optionally clipped) squared error ----
                value = self.critic(mb.obs).flatten()
                plain_sq_err = (mb.returns - value).pow(2)
                if self.value_clip:
                    delta = (value - mb.v_s).clamp(-self.eps_clip, self.eps_clip)
                    v_clip = mb.v_s + delta
                    clipped_sq_err = (mb.returns - v_clip).pow(2)
                    vf_loss = torch.max(plain_sq_err, clipped_sq_err).mean()
                else:
                    vf_loss = plain_sq_err.mean()

                # ---- entropy term and combined objective ----
                ent_loss = dist.entropy().mean()
                loss = clip_loss + self.vf_coef * vf_loss - self.ent_coef * ent_loss

                self.optim.zero_grad()
                loss.backward()
                if self.max_grad_norm:  # clip large gradient
                    nn.utils.clip_grad_norm_(
                        self._actor_critic.parameters(),
                        max_norm=self.max_grad_norm,
                    )
                self.optim.step()

                total_hist.append(loss.item())
                clip_hist.append(clip_loss.item())
                vf_hist.append(vf_loss.item())
                ent_hist.append(ent_loss.item())

        return PPOTrainingStats(  # type: ignore[return-value]
            loss=SequenceSummaryStats.from_sequence(total_hist),
            clip_loss=SequenceSummaryStats.from_sequence(clip_hist),
            vf_loss=SequenceSummaryStats.from_sequence(vf_hist),
            ent_loss=SequenceSummaryStats.from_sequence(ent_hist),
        )