-
Notifications
You must be signed in to change notification settings - Fork 131
/
spot.py
918 lines (778 loc) · 32.2 KB
/
spot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
# source: https://github.com/thuml/SPOT/tree/58c591dc48fbd9ff632b7494eab4caf778e86f4a
# https://arxiv.org/pdf/2202.06239.pdf
import copy
import os
import random
import uuid
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import d4rl
import gym
import numpy as np
import pyrallis
import torch
import torch.distributions as td
import torch.nn as nn
import torch.nn.functional as F
import wandb
# A sampled minibatch: [states, actions, rewards, next_states, dones], in the
# order produced by ReplayBuffer.sample.
TensorBatch = List[torch.Tensor]
# Environment-name prefixes (matched with str.startswith in train) for tasks
# that have a goal; success-rate / regret metrics are only logged for these.
ENVS_WITH_GOAL = ("antmaze", "pen", "door", "hammer", "relocate")
@dataclass
class TrainConfig:
    """Hyper-parameters for offline pre-training and online fine-tuning of SPOT.

    ``__post_init__`` appends the env name and a short random suffix to
    ``name`` so every run is unique, and nests ``checkpoints_path`` under that
    run name when a checkpoint directory is given.
    """

    # Experiment
    device: str = "cuda"
    env: str = "antmaze-umaze-v2"  # OpenAI gym environment name
    seed: int = 0  # Sets Gym, PyTorch and Numpy seeds
    eval_seed: int = 0  # Eval environment seed
    eval_freq: int = int(5e3)  # How often (time steps) we evaluate
    n_episodes: int = 10  # How many episodes run during evaluation
    offline_iterations: int = int(1e6)  # Number of offline updates
    online_iterations: int = int(1e6)  # Number of online updates
    checkpoints_path: Optional[str] = None  # Save path
    load_model: str = ""  # Model load file name, "" doesn't load
    # TD3
    actor_lr: float = 1e-4  # Actor learning rate
    critic_lr: float = 3e-4  # Critic learning rate
    buffer_size: int = 20_000_000  # Replay buffer size
    batch_size: int = 256  # Batch size for all networks
    discount: float = 0.99  # Discount factor (offline phase)
    expl_noise: float = 0.1  # Std of Gaussian exploration noise
    tau: float = 0.005  # Target network update rate
    policy_noise: float = 0.2  # Noise added to target actor during critic update
    noise_clip: float = 0.5  # Range to clip target actor noise
    policy_freq: int = 2  # Frequency of delayed actor updates
    # SPOT VAE
    vae_lr: float = 1e-3  # VAE learning rate
    vae_hidden_dim: int = 750  # VAE hidden layers dimension
    vae_latent_dim: Optional[int] = None  # VAE latent space, 2 * action_dim if None
    beta: float = 0.5  # KL loss weight
    vae_iterations: int = 100_000  # Number of VAE training updates
    # SPOT
    actor_init_w: Optional[float] = None  # Actor head init parameter
    critic_init_w: Optional[float] = None  # Critic head init parameter
    lambd: float = 1.0  # Support constraint weight
    num_samples: int = 1  # Number of samples for density estimation
    iwae: bool = False  # Use IWAE loss
    lambd_cool: bool = False  # Cooling lambda during fine-tune
    # Floor of the cooling factor: during fine-tuning lambda never drops below
    # lambd * lambd_end.
    lambd_end: float = 0.2
    normalize: bool = False  # Normalize states
    normalize_reward: bool = True  # Normalize reward
    online_discount: float = 0.995  # Discount for online tuning
    # Wandb logging
    project: str = "CORL"
    group: str = "SPOT-D4RL"
    name: str = "SPOT"

    def __post_init__(self):
        # Unique run name: "<name>-<env>-<8 hex chars of a uuid4>".
        self.name = f"{self.name}-{self.env}-{str(uuid.uuid4())[:8]}"
        if self.checkpoints_path is not None:
            self.checkpoints_path = os.path.join(self.checkpoints_path, self.name)
def soft_update(target: nn.Module, source: nn.Module, tau: float):
    """Blend ``source`` parameters into ``target`` in place.

    Each target parameter becomes ``(1 - tau) * target + tau * source``; a
    small ``tau`` makes the target track the source slowly.
    """
    paired = zip(target.parameters(), source.parameters())
    for tgt, src in paired:
        blended = (1 - tau) * tgt.data + tau * src.data
        tgt.data.copy_(blended)
def compute_mean_std(states: np.ndarray, eps: float) -> Tuple[np.ndarray, np.ndarray]:
    """Per-feature mean and (eps-padded) standard deviation over axis 0.

    ``eps`` is added to the std so the later division in ``normalize_states``
    never divides by zero for constant features.
    """
    mu = np.mean(states, axis=0)
    sigma = np.std(states, axis=0) + eps
    return mu, sigma
def normalize_states(states: np.ndarray, mean: np.ndarray, std: np.ndarray):
    """Standardise ``states`` with the given mean and std (broadcast per feature)."""
    centered = states - mean
    return centered / std
def wrap_env(
    env: gym.Env,
    state_mean: Union[np.ndarray, float] = 0.0,
    state_std: Union[np.ndarray, float] = 1.0,
    reward_scale: float = 1.0,
) -> gym.Env:
    """Wrap ``env`` so observations are standardised and rewards optionally scaled.

    Observations are always passed through ``(obs - state_mean) / state_std``;
    the reward wrapper is only added when ``reward_scale`` differs from 1.0.
    """

    # Named inner functions rather than lambdas (PEP 8: E731).
    def _standardise(obs):
        # state_std is expected to already contain the epsilon padding.
        return (obs - state_mean) / state_std

    def _rescale(rew):
        # Multiplicative scaling of the raw reward.
        return reward_scale * rew

    wrapped = gym.wrappers.TransformObservation(env, _standardise)
    if reward_scale == 1.0:
        return wrapped
    return gym.wrappers.TransformReward(wrapped, _rescale)
class ReplayBuffer:
    """Fixed-capacity transition store backed by pre-allocated float32 tensors.

    An offline D4RL dataset is copied in once with ``load_d4rl_dataset``;
    online transitions are then appended with ``add_transition``, wrapping
    around and overwriting slot 0 onward once ``buffer_size`` is reached.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        buffer_size: int,
        device: str = "cpu",
    ):
        self._buffer_size = buffer_size
        self._pointer = 0  # next slot add_transition writes to
        self._size = 0  # number of valid rows

        def _alloc(width: int) -> torch.Tensor:
            return torch.zeros((buffer_size, width), dtype=torch.float32, device=device)

        self._states = _alloc(state_dim)
        self._actions = _alloc(action_dim)
        self._rewards = _alloc(1)
        self._next_states = _alloc(state_dim)
        self._dones = _alloc(1)
        self._device = device

    def _to_tensor(self, data: np.ndarray) -> torch.Tensor:
        return torch.tensor(data, dtype=torch.float32, device=self._device)

    # Loads data in d4rl format, i.e. from Dict[str, np.array].
    def load_d4rl_dataset(self, data: Dict[str, np.ndarray]):
        """Copy a D4RL-style dict into the (empty) buffer.

        Raises:
            ValueError: if the buffer already holds data or is too small.
        """
        if self._size != 0:
            raise ValueError("Trying to load data into non-empty replay buffer")
        n_transitions = data["observations"].shape[0]
        if n_transitions > self._buffer_size:
            raise ValueError(
                "Replay buffer is smaller than the dataset you are trying to load!"
            )
        # (storage attribute, dataset key, needs a trailing unit axis)
        layout = (
            ("_states", "observations", False),
            ("_actions", "actions", False),
            ("_rewards", "rewards", True),
            ("_next_states", "next_observations", False),
            ("_dones", "terminals", True),
        )
        for attr, key, add_axis in layout:
            values = data[key][..., None] if add_axis else data[key]
            getattr(self, attr)[:n_transitions] = self._to_tensor(values)
        self._size += n_transitions
        self._pointer = min(self._size, n_transitions)

        print(f"Dataset size: {n_transitions}")

    def sample(self, batch_size: int) -> TensorBatch:
        """Uniformly sample ``batch_size`` rows (with replacement) from the valid region."""
        idx = np.random.randint(0, self._size, size=batch_size)
        storages = (
            self._states,
            self._actions,
            self._rewards,
            self._next_states,
            self._dones,
        )
        return [storage[idx] for storage in storages]

    def add_transition(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ):
        """Write one transition at the cursor (used during online fine-tuning)."""
        slot = self._pointer
        for storage, value in (
            (self._states, state),
            (self._actions, action),
            (self._rewards, reward),
            (self._next_states, next_state),
            (self._dones, done),
        ):
            storage[slot] = self._to_tensor(value)
        self._pointer = (slot + 1) % self._buffer_size
        self._size = min(self._size + 1, self._buffer_size)
def set_env_seed(env: Optional[gym.Env], seed: int):
    """Seed an environment and its action space.

    The parameter is annotated ``Optional``, so a ``None`` env is accepted and
    ignored instead of failing with ``AttributeError``.
    """
    if env is None:
        return
    env.seed(seed)
    env.action_space.seed(seed)
def set_seed(
    seed: int, env: Optional[gym.Env] = None, deterministic_torch: bool = False
):
    """Seed the env (if given), Python, NumPy and PyTorch with ``seed``.

    ``deterministic_torch`` is forwarded to
    ``torch.use_deterministic_algorithms``.
    """
    if env is not None:
        set_env_seed(env, seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
    torch.use_deterministic_algorithms(deterministic_torch)
def wandb_init(config: dict) -> None:
    """Start a wandb run named after ``config`` and log the whole dict as its config.

    ``config`` must contain the keys ``project``, ``group`` and ``name``; each
    run gets a fresh random id.
    """
    run_kwargs = {key: config[key] for key in ("project", "group", "name")}
    wandb.init(config=config, id=str(uuid.uuid4()), **run_kwargs)
    wandb.run.save()
def is_goal_reached(reward: float, info: Dict) -> bool:
    """Return the env's own ``goal_achieved`` flag if present, else ``reward > 0``.

    The fallback assumes that reaching the target is signalled by a positive
    reward.
    """
    return info["goal_achieved"] if "goal_achieved" in info else reward > 0
@torch.no_grad()
def eval_actor(
    env: gym.Env, actor: nn.Module, device: str, n_episodes: int, seed: int
) -> Tuple[np.ndarray, float]:
    """Roll out ``actor`` (no exploration noise) for ``n_episodes`` episodes.

    ``env`` is re-seeded with ``seed`` once, before the first episode. The
    actor is switched to eval mode for the rollouts and then restored to
    whatever mode it was in on entry.

    Returns:
        The undiscounted return of each episode, and the fraction of episodes
        in which ``is_goal_reached`` fired at least once (only meaningful for
        goal-reaching tasks).
    """
    env.seed(seed)
    was_training = actor.training
    actor.eval()
    episode_rewards = []
    successes = []
    for _ in range(n_episodes):
        state, done = env.reset(), False
        episode_reward = 0.0
        goal_achieved = False
        while not done:
            action = actor.act(state, device)
            state, reward, done, env_infos = env.step(action)
            episode_reward += reward
            if not goal_achieved:
                goal_achieved = is_goal_reached(reward, env_infos)
        # Valid only for environments with goal
        successes.append(float(goal_achieved))
        episode_rewards.append(episode_reward)

    # Restore the caller's mode instead of unconditionally forcing train mode.
    actor.train(was_training)
    return np.asarray(episode_rewards), np.mean(successes)
def return_reward_range(dataset: Dict, max_episode_steps: int) -> Tuple[float, float]:
    """Return ``(min, max)`` undiscounted episode return in a flat D4RL dataset.

    An episode ends at a terminal flag or after ``max_episode_steps``
    transitions. A trailing unfinished episode is not counted as a return, but
    its steps still take part in the length sanity check.
    """
    episode_returns = []
    step_counts = []
    running_return, running_len = 0.0, 0
    for rew, term in zip(dataset["rewards"], dataset["terminals"]):
        running_return += float(rew)
        running_len += 1
        if not (term or running_len == max_episode_steps):
            continue
        episode_returns.append(running_return)
        step_counts.append(running_len)
        running_return, running_len = 0.0, 0
    # Account for the leftover partial episode so every transition is counted.
    step_counts.append(running_len)
    assert sum(step_counts) == len(dataset["rewards"])
    return min(episode_returns), max(episode_returns)
def modify_reward(dataset: Dict, env_name: str, max_episode_steps: int = 1000) -> Dict:
    """Rescale ``dataset["rewards"]`` in place for known task families.

    * halfcheetah / hopper / walker2d: divide by the spread of episode returns
      and multiply by ``max_episode_steps``.
    * antmaze: subtract 1 from every reward.
    * anything else: rewards are left untouched.

    Returns the constants needed to repeat the transform online via
    ``modify_reward_online`` for the locomotion tasks, otherwise ``{}``.
    """
    locomotion = ("halfcheetah", "hopper", "walker2d")
    if any(tag in env_name for tag in locomotion):
        min_ret, max_ret = return_reward_range(dataset, max_episode_steps)
        dataset["rewards"] /= max_ret - min_ret
        dataset["rewards"] *= max_episode_steps
        return {
            "max_ret": max_ret,
            "min_ret": min_ret,
            "max_episode_steps": max_episode_steps,
        }
    if "antmaze" in env_name:
        dataset["rewards"] -= 1.0
    return {}
def modify_reward_online(reward: float, env_name: str, **kwargs) -> float:
    """Apply to one online reward the transform ``modify_reward`` used offline.

    For halfcheetah / hopper / walker2d, ``kwargs`` must carry ``max_ret``,
    ``min_ret`` and ``max_episode_steps`` (the dict ``modify_reward`` returns).
    """
    is_locomotion = any(tag in env_name for tag in ("halfcheetah", "hopper", "walker2d"))
    if is_locomotion:
        span = kwargs["max_ret"] - kwargs["min_ret"]
        reward /= span
        reward *= kwargs["max_episode_steps"]
        return reward
    if "antmaze" in env_name:
        reward -= 1.0
    return reward
def weights_init(m: nn.Module, init_w: float = 3e-3):
    """Re-initialise a ``nn.Linear`` uniformly in ``[-init_w, init_w]``; other modules are ignored."""
    if not isinstance(m, nn.Linear):
        return
    # Weight first, then bias (order fixes how the global RNG is consumed).
    for tensor in (m.weight.data, m.bias.data):
        tensor.uniform_(-init_w, init_w)
class VAE(nn.Module):
    """Vanilla conditional Variational Auto-Encoder of actions given states.

    ``encode`` maps (state, action) to a diagonal Gaussian over the latent
    space; ``decode`` maps (state, latent) to an action in
    ``[-max_action, max_action]`` (tanh output scaled by ``max_action``).
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        latent_dim: Optional[int],
        max_action: float,
        hidden_dim: int = 750,
    ):
        super(VAE, self).__init__()
        if latent_dim is None:
            latent_dim = 2 * action_dim
        self.encoder_shared = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )

        self.mean = nn.Linear(hidden_dim, latent_dim)
        self.log_std = nn.Linear(hidden_dim, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(state_dim + latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
            nn.Tanh(),
        )

        self.max_action = max_action
        self.latent_dim = latent_dim

    def forward(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Reparameterised reconstruction: returns ``(decoded_action, mean, std)``."""
        mean, std = self.encode(state, action)
        z = mean + std * torch.randn_like(std)
        u = self.decode(state, z)
        return u, mean, std

    def importance_sampling_estimator(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        beta: float,
        num_samples: int = 500,
    ) -> torch.Tensor:
        """Importance-sampled estimate of ``log p(action | state)``, one value per batch row.

        Assumes 2-D ``(batch, features)`` inputs. ``beta`` sets the fixed std
        of the Gaussian decoder, ``sqrt(beta / 4)``.
        """
        # * num_samples correspond to num of samples L in the paper
        # * note that for exact value for \hat \log \pi_\beta in the paper
        # we also need **an expection over L samples**
        mean, std = self.encode(state, action)

        mean_enc = mean.repeat(num_samples, 1, 1).permute(1, 0, 2)  # [B x S x D]
        std_enc = std.repeat(num_samples, 1, 1).permute(1, 0, 2)  # [B x S x D]
        z = mean_enc + std_enc * torch.randn_like(std_enc)  # [B x S x D]

        state = state.repeat(num_samples, 1, 1).permute(1, 0, 2)  # [B x S x C]
        action = action.repeat(num_samples, 1, 1).permute(1, 0, 2)  # [B x S x C]
        mean_dec = self.decode(state, z)
        std_dec = np.sqrt(beta / 4)

        # Find q(z|x)
        log_qzx = td.Normal(loc=mean_enc, scale=std_enc).log_prob(z)
        # Find p(z); *_like already places these on z's device, so the module
        # needs no separate `device` attribute.
        mu_prior = torch.zeros_like(z)
        std_prior = torch.ones_like(z)
        log_pz = td.Normal(loc=mu_prior, scale=std_prior).log_prob(z)
        # Find p(x|z)
        std_dec = torch.ones_like(mean_dec) * std_dec
        log_pxz = td.Normal(loc=mean_dec, scale=std_dec).log_prob(action)

        # log importance weights per sample, then log-mean-exp over samples.
        w = log_pxz.sum(-1) + log_pz.sum(-1) - log_qzx.sum(-1)
        ll = w.logsumexp(dim=-1) - np.log(num_samples)
        return ll

    def encode(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return mean and std of the approximate posterior q(z | state, action)."""
        z = self.encoder_shared(torch.cat([state, action], -1))

        mean = self.mean(z)
        # Clamped for numerical stability
        log_std = self.log_std(z).clamp(-4, 15)
        std = torch.exp(log_std)
        return mean, std

    def decode(
        self,
        state: torch.Tensor,
        z: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Decode an action; when ``z`` is None a clipped latent is sampled.

        The sampled latent is drawn as before (then moved to ``state``'s
        device) and clipped to [-0.5, 0.5].
        """
        if z is None:
            z = (
                torch.randn((state.shape[0], self.latent_dim))
                .to(state.device)
                .clamp(-0.5, 0.5)
            )
        x = torch.cat([state, z], -1)
        return self.max_action * self.decoder(x)
class Actor(nn.Module):
    """Deterministic policy: 256-256 ReLU MLP with a tanh head scaled to ``max_action``."""

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        init_w: Optional[float] = None,
    ):
        super(Actor, self).__init__()
        # The output layer is created before the hidden layers; creation order
        # decides how the global RNG is consumed during initialisation.
        out_layer = nn.Linear(256, action_dim)
        if init_w is not None:
            weights_init(out_layer, init_w)
        hidden = [
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
        ]
        # Sequential indices (net.0 / net.2 / net.4) define checkpoint keys.
        self.net = nn.Sequential(*hidden, out_layer, nn.Tanh())
        self.max_action = max_action

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.max_action * self.net(state)

    @torch.no_grad()
    def act(self, state: np.ndarray, device: str = "cpu") -> np.ndarray:
        """Return the action for a single numpy observation as a flat numpy array."""
        batch = torch.tensor(state.reshape(1, -1), device=device, dtype=torch.float32)
        out = self(batch)
        return out.cpu().data.numpy().flatten()
class Critic(nn.Module):
    """Q-network: concatenates state and action and maps them through a 256-256 ReLU MLP to a scalar."""

    def __init__(self, state_dim: int, action_dim: int, init_w: Optional[float] = None):
        super(Critic, self).__init__()
        # Head built first: creation order fixes RNG consumption at init.
        q_head = nn.Linear(256, 1)
        if init_w is not None:
            weights_init(q_head, init_w)
        layers = [
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            q_head,
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        joint = torch.cat([state, action], 1)
        return self.net(joint)
class SPOT:
    """SPOT trainer: TD3 with a VAE-based support constraint on the actor.

    Holds the actor, two critics (each with a target copy updated by
    ``soft_update``), the VAE behaviour model, and their optimizers.
    ``vae_train`` fits the VAE. ``train`` performs one TD3 critic update. Every
    ``policy_freq`` calls, it also performs an actor update whose loss adds
    ``lambd`` times a VAE-based estimate of ``-log pi_beta(pi(s) | s)``.
    """

    def __init__(
        self,
        max_action: float,
        actor: nn.Module,
        actor_optimizer: torch.optim.Optimizer,
        critic_1: nn.Module,
        critic_1_optimizer: torch.optim.Optimizer,
        critic_2: nn.Module,
        critic_2_optimizer: torch.optim.Optimizer,
        vae: nn.Module,
        vae_optimizer: torch.optim.Optimizer,
        discount: float = 0.99,
        tau: float = 0.005,
        policy_noise: float = 0.2,
        noise_clip: float = 0.5,
        policy_freq: int = 2,
        beta: float = 0.5,
        lambd: float = 1.0,
        num_samples: int = 1,
        iwae: bool = False,
        lambd_cool: bool = False,
        lambd_end: float = 0.2,
        max_online_steps: int = 1_000_000,
        device: str = "cpu",
    ):
        # Target networks start as exact copies of the online networks.
        self.actor = actor
        self.actor_target = copy.deepcopy(actor)
        self.actor_optimizer = actor_optimizer
        self.critic_1 = critic_1
        self.critic_1_target = copy.deepcopy(critic_1)
        self.critic_1_optimizer = critic_1_optimizer
        self.critic_2 = critic_2
        self.critic_2_target = copy.deepcopy(critic_2)
        self.critic_2_optimizer = critic_2_optimizer

        self.vae = vae
        self.vae_optimizer = vae_optimizer

        self.max_action = max_action
        self.discount = discount
        self.tau = tau
        self.policy_noise = policy_noise
        self.noise_clip = noise_clip
        self.policy_freq = policy_freq
        # KL weight in vae_train / elbo_loss; also passed to the IWAE estimator.
        self.beta = beta
        self.lambd = lambd  # weight of the support-constraint term in the actor loss
        self.num_samples = num_samples  # latent samples per (state, action) pair
        self.iwae = iwae  # True -> iwae_loss, False -> elbo_loss
        self.lambd_cool = lambd_cool
        self.lambd_end = lambd_end
        self.max_online_steps = max_online_steps

        # is_online is flipped by the training loop when fine-tuning starts;
        # online_it then counts online updates and drives lambda cooling.
        self.is_online = False
        self.online_it = 0
        self.total_it = 0  # incremented by both vae_train and train
        self.device = device

    def elbo_loss(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        beta: float,
        num_samples: int = 1,
    ) -> torch.Tensor:
        """
        Per-row ELBO-based loss of ``action`` under the VAE (lower = more in-support).

        Note: elbo_loss one is proportional to elbo_estimator
        i.e. there exist a>0 and b, elbo_loss = a * (-elbo_estimator) + b

        Returns one value per batch row: reconstruction MSE averaged over the
        ``num_samples`` latent draws and action dims, plus ``beta`` times the
        KL term averaged over latent dims (assuming 2-D (batch, features)
        inputs — TODO confirm for other shapes).
        """
        mean, std = self.vae.encode(state, action)

        mean_s = mean.repeat(num_samples, 1, 1).permute(1, 0, 2)  # [B x S x D]
        std_s = std.repeat(num_samples, 1, 1).permute(1, 0, 2)  # [B x S x D]
        z = mean_s + std_s * torch.randn_like(std_s)

        state = state.repeat(num_samples, 1, 1).permute(1, 0, 2)  # [B x S x C]
        action = action.repeat(num_samples, 1, 1).permute(1, 0, 2)  # [B x S x C]
        u = self.vae.decode(state, z)
        recon_loss = ((u - action) ** 2).mean(dim=(1, 2))

        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean(-1)
        vae_loss = recon_loss + beta * KL_loss
        return vae_loss

    def iwae_loss(
        self,
        state: torch.Tensor,
        action: torch.Tensor,
        beta: float,
        num_samples: int = 10,
    ) -> torch.Tensor:
        """Negative importance-sampled log-likelihood of ``action`` under the VAE, per row."""
        ll = self.vae.importance_sampling_estimator(state, action, beta, num_samples)
        return -ll

    def vae_train(self, batch: TensorBatch) -> Dict[str, float]:
        """One VAE gradient step on the (state, action) part of ``batch``.

        Also increments ``total_it``, so the later actor-update schedule and
        wandb step counter continue from the VAE-phase count.
        """
        log_dict = {}
        self.total_it += 1

        state, action, _, _, _ = batch
        # Variational Auto-Encoder Training
        recon, mean, std = self.vae(state, action)
        recon_loss = F.mse_loss(recon, action)
        KL_loss = -0.5 * (1 + torch.log(std.pow(2)) - mean.pow(2) - std.pow(2)).mean()
        vae_loss = recon_loss + self.beta * KL_loss

        self.vae_optimizer.zero_grad()
        vae_loss.backward()
        self.vae_optimizer.step()

        log_dict["VAE/reconstruction_loss"] = recon_loss.item()
        log_dict["VAE/KL_loss"] = KL_loss.item()
        log_dict["VAE/vae_loss"] = vae_loss.item()

        return log_dict

    def train(self, batch: TensorBatch) -> Dict[str, float]:
        """One SPOT update step; returns scalar metrics for logging.

        Always updates both critics toward a clipped-double-Q TD3 target.
        Every ``policy_freq``-th call (by ``total_it``), also updates the actor
        and soft-updates all three target networks.
        """
        log_dict = {}
        self.total_it += 1
        if self.is_online:
            self.online_it += 1

        state, action, reward, next_state, done = batch
        not_done = 1 - done

        with torch.no_grad():
            # Select action according to actor and add clipped noise
            noise = (torch.randn_like(action) * self.policy_noise).clamp(
                -self.noise_clip, self.noise_clip
            )

            next_action = (self.actor_target(next_state) + noise).clamp(
                -self.max_action, self.max_action
            )

            # Compute the target Q value
            target_q1 = self.critic_1_target(next_state, next_action)
            target_q2 = self.critic_2_target(next_state, next_action)
            target_q = torch.min(target_q1, target_q2)
            target_q = reward + not_done * self.discount * target_q

        # Get current Q estimates
        current_q1 = self.critic_1(state, action)
        current_q2 = self.critic_2(state, action)

        # Compute critic loss
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
        log_dict["critic_loss"] = critic_loss.item()
        # Optimize the critic
        self.critic_1_optimizer.zero_grad()
        self.critic_2_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_1_optimizer.step()
        self.critic_2_optimizer.step()

        # Delayed actor updates
        if self.total_it % self.policy_freq == 0:
            # Compute actor loss
            pi = self.actor(state)
            q = self.critic_1(state, pi)

            # Behaviour-density penalty for the actor's own actions.
            if self.iwae:
                neg_log_beta = self.iwae_loss(state, pi, self.beta, self.num_samples)
            else:
                neg_log_beta = self.elbo_loss(state, pi, self.beta, self.num_samples)

            # Optional linear cooling of lambda over the online phase, floored
            # at lambd * lambd_end.
            if self.lambd_cool:
                lambd = self.lambd * max(
                    self.lambd_end, (1.0 - self.online_it / self.max_online_steps)
                )
            else:
                lambd = self.lambd

            # Q term is divided by the (detached) mean |Q| so its scale is
            # independent of the reward magnitude.
            norm_q = 1 / q.abs().mean().detach()

            actor_loss = -norm_q * q.mean() + lambd * neg_log_beta.mean()

            log_dict["actor_loss"] = actor_loss.item()
            log_dict["neg_log_beta_mean"] = neg_log_beta.mean().item()
            log_dict["neg_log_beta_max"] = neg_log_beta.max().item()
            log_dict["lambd"] = lambd

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            soft_update(self.critic_1_target, self.critic_1, self.tau)
            soft_update(self.critic_2_target, self.critic_2, self.tau)
            soft_update(self.actor_target, self.actor, self.tau)

        return log_dict

    def state_dict(self) -> Dict[str, Any]:
        """Checkpoint of networks, optimizers and ``total_it``.

        Target networks, ``online_it`` and ``is_online`` are not included.
        """
        return {
            "vae": self.vae.state_dict(),
            "vae_optimizer": self.vae_optimizer.state_dict(),
            "critic_1": self.critic_1.state_dict(),
            "critic_1_optimizer": self.critic_1_optimizer.state_dict(),
            "critic_2": self.critic_2.state_dict(),
            "critic_2_optimizer": self.critic_2_optimizer.state_dict(),
            "actor": self.actor.state_dict(),
            "actor_optimizer": self.actor_optimizer.state_dict(),
            "total_it": self.total_it,
        }

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Restore from ``state_dict()`` output; targets are rebuilt as copies of the loaded nets."""
        self.vae.load_state_dict(state_dict["vae"])
        self.vae_optimizer.load_state_dict(state_dict["vae_optimizer"])

        self.critic_1.load_state_dict(state_dict["critic_1"])
        self.critic_1_optimizer.load_state_dict(state_dict["critic_1_optimizer"])
        self.critic_1_target = copy.deepcopy(self.critic_1)

        self.critic_2.load_state_dict(state_dict["critic_2"])
        self.critic_2_optimizer.load_state_dict(state_dict["critic_2_optimizer"])
        self.critic_2_target = copy.deepcopy(self.critic_2)

        self.actor.load_state_dict(state_dict["actor"])
        self.actor_optimizer.load_state_dict(state_dict["actor_optimizer"])
        self.actor_target = copy.deepcopy(self.actor)

        self.total_it = state_dict["total_it"]
@pyrallis.wrap()
def train(config: TrainConfig):
    """Run SPOT end to end: VAE fitting, offline pre-training, online fine-tuning.

    1. ``vae_iterations`` VAE updates on the offline D4RL dataset.
    2. ``offline_iterations`` SPOT updates on that dataset.
    3. ``online_iterations`` SPOT updates while collecting new transitions
       (deterministic actor + clipped Gaussian noise) into the same buffer.

    Every ``eval_freq`` steps of phases 2-3, the actor is evaluated on a
    separate env. Metrics are sent to wandb, and checkpoints are written when
    ``checkpoints_path`` is set.
    """
    env = gym.make(config.env)
    eval_env = gym.make(config.env)

    is_env_with_goal = config.env.startswith(ENVS_WITH_GOAL)

    # Episode length limit; used to tell time-outs apart from real terminations.
    max_steps = env._max_episode_steps

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    dataset = d4rl.qlearning_dataset(env)

    reward_mod_dict = {}
    if config.normalize_reward:
        # Keep the constants so online rewards receive the same transform.
        reward_mod_dict = modify_reward(dataset, config.env)

    if config.normalize:
        state_mean, state_std = compute_mean_std(dataset["observations"], eps=1e-3)
    else:
        state_mean, state_std = 0, 1

    dataset["observations"] = normalize_states(
        dataset["observations"], state_mean, state_std
    )
    dataset["next_observations"] = normalize_states(
        dataset["next_observations"], state_mean, state_std
    )
    # Both envs see observations normalised with the dataset statistics.
    env = wrap_env(env, state_mean=state_mean, state_std=state_std)
    eval_env = wrap_env(eval_env, state_mean=state_mean, state_std=state_std)
    replay_buffer = ReplayBuffer(
        state_dim,
        action_dim,
        config.buffer_size,
        config.device,
    )
    replay_buffer.load_d4rl_dataset(dataset)

    max_action = float(env.action_space.high[0])

    if config.checkpoints_path is not None:
        print(f"Checkpoints path: {config.checkpoints_path}")
        os.makedirs(config.checkpoints_path, exist_ok=True)
        with open(os.path.join(config.checkpoints_path, "config.yaml"), "w") as f:
            pyrallis.dump(config, f)

    # Set seeds
    seed = config.seed
    set_seed(seed, env)
    set_env_seed(eval_env, config.eval_seed)

    vae = VAE(
        state_dim, action_dim, config.vae_latent_dim, max_action, config.vae_hidden_dim
    ).to(config.device)
    vae_optimizer = torch.optim.Adam(vae.parameters(), lr=config.vae_lr)

    actor = Actor(state_dim, action_dim, max_action, config.actor_init_w).to(
        config.device
    )
    actor_optimizer = torch.optim.Adam(actor.parameters(), lr=config.actor_lr)

    critic_1 = Critic(state_dim, action_dim, config.critic_init_w).to(config.device)
    critic_1_optimizer = torch.optim.Adam(critic_1.parameters(), lr=config.critic_lr)
    critic_2 = Critic(state_dim, action_dim, config.critic_init_w).to(config.device)
    critic_2_optimizer = torch.optim.Adam(critic_2.parameters(), lr=config.critic_lr)

    kwargs = {
        "max_action": max_action,
        "vae": vae,
        "vae_optimizer": vae_optimizer,
        "actor": actor,
        "actor_optimizer": actor_optimizer,
        "critic_1": critic_1,
        "critic_1_optimizer": critic_1_optimizer,
        "critic_2": critic_2,
        "critic_2_optimizer": critic_2_optimizer,
        "discount": config.discount,
        "tau": config.tau,
        "device": config.device,
        # TD3
        "policy_noise": config.policy_noise * max_action,
        "noise_clip": config.noise_clip * max_action,
        "policy_freq": config.policy_freq,
        # SPOT
        "beta": config.beta,
        "lambd": config.lambd,
        "num_samples": config.num_samples,
        "iwae": config.iwae,
        "lambd_cool": config.lambd_cool,
        "lambd_end": config.lambd_end,
        "max_online_steps": config.online_iterations,
    }

    print("---------------------------------------")
    print(f"Training SPOT, Env: {config.env}, Seed: {seed}")
    print("---------------------------------------")

    # Initialize actor
    trainer = SPOT(**kwargs)

    if config.load_model != "":
        policy_file = Path(config.load_model)
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        # map_location lets a checkpoint saved on another device load here.
        trainer.load_state_dict(torch.load(policy_file, map_location=config.device))
        actor = trainer.actor

    wandb_init(asdict(config))

    evaluations = []

    print("Training VAE")
    for t in range(int(config.vae_iterations)):
        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        log_dict = trainer.vae_train(batch)
        log_dict["vae_iter"] = t
        wandb.log(log_dict, step=trainer.total_it)

    vae.eval()

    state, done = env.reset(), False
    episode_return = 0
    episode_step = 0
    goal_achieved = False

    eval_successes = []
    train_successes = []

    print("Offline pretraining")
    for t in range(int(config.offline_iterations) + int(config.online_iterations)):
        if t == config.offline_iterations:
            print("Online tuning")
            trainer.is_online = True
            trainer.discount = config.online_discount
            # Resetting optimizers
            trainer.actor_optimizer = torch.optim.Adam(
                actor.parameters(), lr=config.actor_lr
            )
            trainer.critic_1_optimizer = torch.optim.Adam(
                critic_1.parameters(), lr=config.critic_lr
            )
            trainer.critic_2_optimizer = torch.optim.Adam(
                critic_2.parameters(), lr=config.critic_lr
            )
        online_log = {}
        if t >= config.offline_iterations:
            episode_step += 1
            # Acting only: no autograd graph is needed for the exploration step.
            with torch.no_grad():
                action = actor(
                    torch.tensor(
                        state.reshape(1, -1), device=config.device, dtype=torch.float32
                    )
                )
                # Noise is a fraction of the action range, like the TD3 target
                # noise passed to SPOT above.
                noise = (torch.randn_like(action) * config.expl_noise * max_action).clamp(
                    -config.noise_clip * max_action, config.noise_clip * max_action
                )
                # The actor output is already scaled by max_action, so only
                # add the noise and clip back into range (no second scaling).
                action = (action + noise).clamp(-max_action, max_action)
            action = action.cpu().numpy().flatten()
            next_state, reward, done, env_infos = env.step(action)
            if not goal_achieved:
                goal_achieved = is_goal_reached(reward, env_infos)

            episode_return += reward

            real_done = False  # Episode can timeout which is different from done
            if done and episode_step < max_steps:
                real_done = True

            if config.normalize_reward:
                reward = modify_reward_online(reward, config.env, **reward_mod_dict)

            replay_buffer.add_transition(state, action, reward, next_state, real_done)
            state = next_state
            if done:
                state, done = env.reset(), False
                # Valid only for envs with goal, e.g. AntMaze, Adroit
                if is_env_with_goal:
                    train_successes.append(goal_achieved)
                    online_log["train/regret"] = np.mean(1 - np.array(train_successes))
                    online_log["train/is_success"] = float(goal_achieved)
                online_log["train/episode_return"] = episode_return
                normalized_return = eval_env.get_normalized_score(episode_return)
                online_log["train/d4rl_normalized_episode_return"] = (
                    normalized_return * 100.0
                )
                online_log["train/episode_length"] = episode_step
                episode_return = 0
                episode_step = 0
                goal_achieved = False

        batch = replay_buffer.sample(config.batch_size)
        batch = [b.to(config.device) for b in batch]
        log_dict = trainer.train(batch)
        log_dict["offline_iter" if t < config.offline_iterations else "online_iter"] = (
            t if t < config.offline_iterations else t - config.offline_iterations
        )
        log_dict.update(online_log)
        wandb.log(log_dict, step=trainer.total_it)
        # Evaluate episode
        if (t + 1) % config.eval_freq == 0:
            print(f"Time steps: {t + 1}")
            # eval_actor re-seeds the env on every call; use the dedicated
            # eval seed so it is not silently overridden by the training seed.
            eval_scores, success_rate = eval_actor(
                eval_env,
                actor,
                device=config.device,
                n_episodes=config.n_episodes,
                seed=config.eval_seed,
            )
            eval_score = eval_scores.mean()
            eval_log = {}
            normalized = eval_env.get_normalized_score(np.mean(eval_scores))
            # Valid only for envs with goal, e.g. AntMaze, Adroit
            if t >= config.offline_iterations and is_env_with_goal:
                eval_successes.append(success_rate)
                # NOTE(review): "eval/regret" is computed from train_successes
                # (online training episodes); eval_successes is collected but
                # not logged - confirm this is the intended metric.
                eval_log["eval/regret"] = np.mean(1 - np.array(train_successes))
                eval_log["eval/success_rate"] = success_rate
            normalized_eval_score = normalized * 100.0
            eval_log["eval/d4rl_normalized_score"] = normalized_eval_score
            evaluations.append(normalized_eval_score)
            print("---------------------------------------")
            print(
                f"Evaluation over {config.n_episodes} episodes: "
                f"{eval_score:.3f} , D4RL score: {normalized_eval_score:.3f}"
            )
            print("---------------------------------------")
            if config.checkpoints_path is not None:
                torch.save(
                    trainer.state_dict(),
                    os.path.join(config.checkpoints_path, f"checkpoint_{t}.pt"),
                )
            wandb.log(eval_log, step=trainer.total_it)
if __name__ == "__main__":
    # train is decorated with pyrallis.wrap(), which supplies the TrainConfig
    # argument (presumably parsed from CLI args / a config file - see pyrallis docs).
    train()