Commit 54c9f0d

Author: Xingdong Zuo (committed)

Update

Former-commit-id: d615ccd64a2089a302e61651b9761498d3199b44 [formerly f131ab52da6f638cd55a494f9b3ec7d54f6baaa5]
Former-commit-id: ec4abd8754806888f3a47ac67985a8011a751a2b

Parent: 78ab69f

2 files changed: 290 additions & 10 deletions

examples/policy_gradient/vpg/main.ipynb

Lines changed: 290 additions & 1 deletion
@@ -1,5 +1,294 @@
 {
  "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "ImportError",
+     "evalue": "libcudart.so.9.2: cannot open shared object file: No such file or directory",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
+      "\u001b[0;32m<ipython-input-1-587e5575a1c6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;32m~/anaconda3/envs/RL/lib/python3.7/site-packages/torch/__init__.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m __all__ += [name for name in dir(_C)\n",
+      "\u001b[0;31mImportError\u001b[0m: libcudart.so.9.2: cannot open shared object file: No such file or directory"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "import torch\n",
+    "import torch.optim as optim\n",
+    "import torch.nn as nn\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "from torch.nn.utils import clip_grad_norm_\n",
+    "\n",
+    "from lagom.networks import BaseNetwork\n",
+    "from lagom.networks import make_fc\n",
+    "from lagom.networks import ortho_init\n",
+    "from lagom.networks import linear_lr_scheduler\n",
+    "\n",
+    "from lagom.policies import BasePolicy\n",
+    "from lagom.policies import CategoricalHead\n",
+    "from lagom.policies import DiagGaussianHead\n",
+    "from lagom.policies import constraint_action\n",
+    "\n",
+    "from lagom.value\n",
+    "\n",
+    "from lagom.transform import Standardize\n",
+    "\n",
+    "from lagom.agents import BaseAgent\n",
+    "\n",
+    "\n",
+    "class MLP(BaseNetwork):\n",
+    "    def make_params(self, config):\n",
+    "        self.feature_layers = make_fc(self.env_spec.observation_space.flat_dim, config['network.hidden_sizes'])\n",
+    " \n",
+    "    def init_params(self, config):\n",
+    "        for layer in self.feature_layers:\n",
+    "            ortho_init(layer, nonlinearity='tanh', constant_bias=0.0)\n",
+    " \n",
+    "    def reset(self, config, **kwargs):\n",
+    "        pass\n",
+    " \n",
+    "    def forward(self, x):\n",
+    "        for layer in self.feature_layers:\n",
+    "            x = torch.tanh(layer(x))\n",
+    " \n",
+    "        return x\n",
+    " \n",
+    " \n",
+    "class Policy(BasePolicy):\n",
+    "    def make_networks(self, config):\n",
+    "        self.feature_network = MLP(config, self.device, env_spec=self.env_spec)\n",
+    "        feature_dim = config['network.hidden_sizes'][-1]\n",
+    " \n",
+    "        if self.env_spec.control_type == 'Discrete':\n",
+    "            self.action_head = CategoricalHead(config, self.device, feature_dim, self.env_spec)\n",
+    "        elif self.env_spec.control_type == 'Continuous':\n",
+    "            self.action_head = DiagGaussianHead(config, \n",
+    "                                                self.device, \n",
+    "                                                feature_dim, \n",
+    "                                                self.env_spec, \n",
+    "                                                min_std=config['agent.min_std'], \n",
+    "                                                std_style=config['agent.std_style'], \n",
+    "                                                constant_std=config['agent.constant_std'],\n",
+    "                                                std_state_dependent=config['agent.std_state_dependent'],\n",
+    "                                                init_std=config['agent.init_std'])\n",
+    " \n",
+    "    @property\n",
+    "    def recurrent(self):\n",
+    "        return False\n",
+    " \n",
+    "    def reset(self, config, **kwargs):\n",
+    "        pass\n",
+    "\n",
+    "    def __call__(self, x, out_keys=['action'], info={}, **kwargs):\n",
+    "        out = {}\n",
+    " \n",
+    "        features = self.feature_network(x)\n",
+    "        action_dist = self.action_head(features)\n",
+    " \n",
+    "        action = action_dist.sample().detach()################################\n",
+    "        out['action'] = action\n",
+    " \n",
+    "        if 'action_logprob' in out_keys:\n",
+    "            out['action_logprob'] = action_dist.log_prob(action)\n",
+    "        if 'entropy' in out_keys:\n",
+    "            out['entropy'] = action_dist.entropy()\n",
+    "        if 'perplexity' in out_keys:\n",
+    "            out['perplexity'] = action_dist.perplexity()\n",
+    " \n",
+    "        return out\n",
+    " \n",
+    "\n",
+    "class Agent(BaseAgent):\n",
+    "    r\"\"\"REINFORCE (no baseline). \"\"\"\n",
+    "    def make_modules(self, config):\n",
+    "        self.policy = Policy(config, self.env_spec, self.device)\n",
+    " \n",
+    "    def prepare(self, config, **kwargs):\n",
+    "        self.total_T = 0\n",
+    "        self.optimizer = optim.Adam(self.policy.parameters(), lr=config['algo.lr'])\n",
+    "        if config['algo.use_lr_scheduler']:\n",
+    "            if 'train.iter' in config:\n",
+    "                self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.iter'], 'iteration-based')\n",
+    "            elif 'train.timestep' in config:\n",
+    "                self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.timestep']+1, 'timestep-based')\n",
+    "        else:\n",
+    "            self.lr_scheduler = None\n",
+    " \n",
+    "\n",
+    "    def reset(self, config, **kwargs):\n",
+    "        pass\n",
+    "\n",
+    "    def choose_action(self, obs, info={}):\n",
+    "        obs = torch.from_numpy(np.asarray(obs)).float().to(self.device)\n",
+    " \n",
+    "        out = self.policy(obs, out_keys=['action', 'action_logprob', 'entropy'], info=info)\n",
+    " \n",
+    "        # sanity check for NaN\n",
+    "        if torch.any(torch.isnan(out['action'])):\n",
+    "            while True:\n",
+    "                print('NaN !')\n",
+    "        if self.env_spec.control_type == 'Continuous':\n",
+    "            out['action'] = constraint_action(self.env_spec, out['action'])\n",
+    " \n",
+    "        return out\n",
+    "\n",
+    "    def learn(self, D, info={}):\n",
+    "        batch_policy_loss = []\n",
+    "        batch_entropy_loss = []\n",
+    "        batch_total_loss = []\n",
+    " \n",
+    "        for trajectory in D:\n",
+    "            logprobs = trajectory.all_info('action_logprob')\n",
+    "            entropies = trajectory.all_info('entropy')\n",
+    "            Qs = trajectory.all_discounted_returns(self.config['algo.gamma'])\n",
+    " \n",
+    "            # Standardize: encourage/discourage half of performed actions\n",
+    "            if self.config['agent.standardize_Q']:\n",
+    "                Qs = Standardize()(Qs, -1).tolist()\n",
+    " \n",
+    "            policy_loss = []\n",
+    "            entropy_loss = []\n",
+    "            for logprob, entropy, Q in zip(logprobs, entropies, Qs):\n",
+    "                policy_loss.append(-logprob*Q)\n",
+    "                entropy_loss.append(-entropy)\n",
+    " \n",
+    "            policy_loss = torch.stack(policy_loss).mean()\n",
+    "            entropy_loss = torch.stack(entropy_loss).mean()\n",
+    " \n",
+    "            entropy_coef = self.config['agent.entropy_coef']\n",
+    "            total_loss = policy_loss + entropy_coef*entropy_loss\n",
+    " \n",
+    "            batch_policy_loss.append(policy_loss)\n",
+    "            batch_entropy_loss.append(entropy_loss)\n",
+    "            batch_total_loss.append(total_loss)\n",
+    " \n",
+    "        policy_loss = torch.stack(batch_policy_loss).mean()\n",
+    "        entropy_loss = torch.stack(batch_entropy_loss).mean()\n",
+    "        loss = torch.stack(batch_total_loss).mean()\n",
+    " \n",
+    "        self.optimizer.zero_grad()\n",
+    "        loss.backward()\n",
+    " \n",
+    "        if self.config['agent.max_grad_norm'] is not None:\n",
+    "            clip_grad_norm_(self.parameters(), self.config['agent.max_grad_norm'])\n",
+    " \n",
+    "        if self.lr_scheduler is not None:\n",
+    "            if self.lr_scheduler.mode == 'iteration-based':\n",
+    "                self.lr_scheduler.step()\n",
+    "            elif self.lr_scheduler.mode == 'timestep-based':\n",
+    "                self.lr_scheduler.step(self.total_T)\n",
+    "\n",
+    "        self.optimizer.step()\n",
+    " \n",
+    "        self.total_T += sum([trajectory.T for trajectory in D])\n",
+    " \n",
+    "        out = {}\n",
+    "        out['loss'] = loss.item()\n",
+    "        out['policy_loss'] = policy_loss.item()\n",
+    "        out['entropy_loss'] = entropy_loss.item()\n",
+    "        if self.lr_scheduler is not None:\n",
+    "            out['current_lr'] = self.lr_scheduler.get_lr()\n",
+    "\n",
+    "        return out\n",
+    " \n",
+    "    @property\n",
+    "    def recurrent(self):\n",
+    "        pass\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
   {
    "cell_type": "code",
    "execution_count": 24,
@@ -226,7 +515,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.6.6"
+   "version": "3.7.0"
   }
  },
  "nbformat": 4,

examples/policy_gradient/vpg/model.py

Lines changed: 0 additions & 9 deletions
@@ -1,13 +1,4 @@
-import numpy as np
 
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn.utils import clip_grad_norm_
-
-from .base_agent import BaseAgent
-
-from lagom.core.transform import Standardize
 
 
 class VPGAgent(BaseAgent):
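
Besides the loss itself, the notebook's prepare() wires up linear_lr_scheduler so the learning rate decays linearly over training, stepped either per iteration or per environment timestep. A rough stand-in for the 'iteration-based' branch using only torch.optim is sketched below; N, the model, and the dummy objective are illustrative assumptions, not part of this commit.

# Linear decay of the learning rate from its initial value to 0 over N updates,
# a rough equivalent of lagom's linear_lr_scheduler in 'iteration-based' mode.
import torch
from torch import nn, optim

N = 1000  # assumed total number of training iterations
model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda i: max(0.0, 1.0 - i / N))

for i in range(N):
    loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()  # shrink the lr after each update, as in the 'iteration-based' branch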
