REINFORCE.py
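
# REINFORCE (Monte Carlo policy gradient) on CartPole-v1.
# A small two-layer policy network is trained by collecting a batch of
# episodes, computing per-step returns, and taking a gradient step on
# -sum(log pi(a_t | s_t) * G_t) / episodes. The script uses the gym >= 0.26
# API (reset returns (obs, info); step returns a 5-tuple).
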
import gym
import numpy as np
from itertools import count
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
from torch.distributions import Categorical
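
# Hyperparameters. gamma = 1.0 gives undiscounted returns; `episodes` is the
# number of episodes collected before each policy update. `seed` and `f_star`
# are defined but not used below (the reset seeding is commented out).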
lr = 0.0003  # SGD: 0.00003, Adam: 0.0003
gamma = 1.0
seed = 543
log_interval = 10
episodes = 5
env = gym.make("CartPole-v1")
f_star = -500
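
# Two-layer policy network: the 4-dimensional CartPole observation is mapped
# to 128 hidden units and then to a softmax over the 2 discrete actions.
# The lists below buffer log-probabilities, rewards, and terminal masks for
# the current batch of episodes.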
class Policy(nn.Module):
    def __init__(self):
        super(Policy, self).__init__()
        self.affine1 = nn.Linear(4, 128)
        self.affine2 = nn.Linear(128, 2)
        self.saved_log_probs = []
        self.rewards = []
        self.masks = []

    def forward(self, x):
        x = F.relu(self.affine1(x))
        action_scores = self.affine2(x)
        return F.softmax(action_scores, dim=1)
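
# Global policy and optimizer. SGD is used here with lr = 0.0003; the inline
# comment on `lr` above suggests 3e-5 for SGD and 3e-4 for Adam, so the
# current value matches the Adam suggestion.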
policy = Policy()
optimizer = optim.SGD(policy.parameters(), lr=lr)
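
# Sample an action from the categorical distribution defined by the policy's
# softmax output, and stash its log-probability for the update step.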
def select_action(state):
    state = torch.from_numpy(state).float().unsqueeze(0)
    probs = policy(state)
    m = Categorical(probs)
    action = m.sample()
    policy.saved_log_probs.append(m.log_prob(action))
    return action.item()
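
# One REINFORCE update over the whole batch: walk the buffered rewards
# backwards to build per-step returns G_t (a mask of 0 zeroes the running
# return at episode boundaries), then minimize -sum(log_prob * G) / episodes.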
def finish_episode():
    R = 0
    policy_loss = []
    rewards = []
    for i in range(1, len(policy.rewards) + 1):
        R = policy.rewards[-i] + gamma * R * policy.masks[-i]
        rewards.insert(0, R)
    rewards = torch.tensor(rewards)
    for log_prob, reward in zip(policy.saved_log_probs, rewards):
        policy_loss.append(-log_prob * reward)
    policy_loss = torch.cat(policy_loss).sum() / episodes
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()
    # Clear all three rollout buffers so the next batch starts fresh;
    # leaving masks in place would let them drift out of alignment with
    # the rewards collected later.
    del policy.rewards[:]
    del policy.saved_log_probs[:]
    del policy.masks[:]
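
# Training loop: each outer iteration collects `episodes` episodes, runs one
# policy update, and tracks an exponential moving average of episode length.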
def main():
    running_reward = 10
    for i_episode in count(1):
        state, info = env.reset()  # seed=args.seed)
        for episode in range(episodes):
            for t in range(10000):  # Don't infinite loop while learning
                action = select_action(state)
                state, reward, done, truncated, _ = env.step(action)
                policy.rewards.append(reward)
                # One mask per step: 0 on the terminal step so the return
                # computation does not leak across episode boundaries.
                policy.masks.append(0 if (done or truncated) else 1)
                if done or truncated:
                    state, info = env.reset()  # seed=args.seed)
                    break
            running_reward = running_reward * 0.99 + t * 0.01
        finish_episode()
        if i_episode % log_interval == 0:
            print(
                "Episode {}\tLast length: {:5d}\tAverage length: {:.2f}".format(
                    i_episode, t, running_reward
                )
            )
        if i_episode == 5000:
            break


if __name__ == "__main__":
    main()