util.py
import os
import subprocess
import sys
import importlib
import inspect
import functools

import tensorflow as tf
import numpy as np

from common import tf_util as U


def store_args(method):
    """Stores provided method args as instance attributes.
    """
    argspec = inspect.getfullargspec(method)
    defaults = {}
    if argspec.defaults is not None:
        defaults = dict(
            zip(argspec.args[-len(argspec.defaults):], argspec.defaults))
    if argspec.kwonlydefaults is not None:
        defaults.update(argspec.kwonlydefaults)
    arg_names = argspec.args[1:]

    @functools.wraps(method)
    def wrapper(*positional_args, **keyword_args):
        self = positional_args[0]
        # Get default arg values
        args = defaults.copy()
        # Add provided arg values
        for name, value in zip(arg_names, positional_args[1:]):
            args[name] = value
        args.update(keyword_args)
        self.__dict__.update(args)
        return method(*positional_args, **keyword_args)

    return wrapper
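

# Illustrative usage of `store_args` (a sketch, not part of the original
# module; the class and argument names below are hypothetical): decorating
# __init__ makes every constructor argument available as an attribute.
#
#     class Agent:
#         @store_args
#         def __init__(self, buffer_size, polyak=0.95, gamma=0.98, **kwargs):
#             pass  # self.buffer_size, self.polyak, self.gamma are now set
#
#     agent = Agent(buffer_size=int(1e6))
#     assert agent.gamma == 0.98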


def import_function(spec):
    """Import a function identified by a string like "pkg.module:fn_name".
    """
    mod_name, fn_name = spec.split(':')
    module = importlib.import_module(mod_name)
    fn = getattr(module, fn_name)
    return fn
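

# Illustrative usage of `import_function` (the spec string below is an
# assumption for the example, not a module that necessarily exists here):
#
#     make_env = import_function('envs.registry:make_env')
#     env = make_env()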


def flatten_grads(var_list, grads):
    """Flattens variables and their gradients into a single 1-D tensor.
    """
    return tf.concat(
        [tf.reshape(grad, [U.numel(v)]) for (v, grad) in zip(var_list, grads)],
        0)


def nn(input, layers_sizes, reuse=None, flatten=False, name=""):
    """Creates a simple fully-connected neural network with ReLU activations
    on all hidden layers and a linear output layer.
    """
    for i, size in enumerate(layers_sizes):
        activation = tf.nn.relu if i < len(layers_sizes) - 1 else None
        input = tf.layers.dense(
            inputs=input,
            units=size,
            kernel_initializer=tf.contrib.layers.xavier_initializer(),
            reuse=reuse,
            name=name + '_' + str(i))
        if activation:
            input = activation(input)
    if flatten:
        assert layers_sizes[-1] == 1
        input = tf.reshape(input, [-1])
    return input
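

# Illustrative usage of `nn` (shapes chosen for the example only): a two-layer
# head that maps a batch of 10-dimensional inputs to one scalar per sample and
# flattens the result. Note this relies on the TF1 `tf.layers`/`tf.contrib` API
# used above.
#
#     x = tf.placeholder(tf.float32, shape=(None, 10))
#     q = nn(x, layers_sizes=[64, 64, 1], flatten=True, name='Q')
#     # q has shape (batch_size,)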


def install_mpi_excepthook():
    """Installs an excepthook that aborts all MPI workers if any one of them
    raises an uncaught exception, instead of leaving the job hanging.
    """
    import sys
    from mpi4py import MPI
    old_hook = sys.excepthook

    def new_hook(a, b, c):
        old_hook(a, b, c)
        sys.stdout.flush()
        sys.stderr.flush()
        MPI.COMM_WORLD.Abort()

    sys.excepthook = new_hook


def mpi_fork(n):
    """Re-launches the current script with `n` MPI workers.

    Returns "parent" for the original parent process and "child" for the MPI
    children.
    """
    if n <= 1:
        return "child"
    if os.getenv("IN_MPI") is None:
        env = os.environ.copy()
        env.update(MKL_NUM_THREADS="1", OMP_NUM_THREADS="1", IN_MPI="1")
        # "-bind-to core" is crucial for good performance
        args = ["mpirun", "-np", str(n), "-bind-to", "core", sys.executable]
        args += sys.argv
        subprocess.check_call(args, env=env)
        return "parent"
    else:
        install_mpi_excepthook()
        return "child"


def convert_episode_to_batch_major(episode):
    """Converts an episode to have the batch dimension in the major (first)
    dimension.
    """
    episode_batch = {}
    for key in episode.keys():
        val = np.array(episode[key]).copy()
        # make inputs batch-major instead of time-major
        episode_batch[key] = val.swapaxes(0, 1)
    return episode_batch
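

# Illustrative shape example for `convert_episode_to_batch_major` (values are
# made up): time-major data from 16 parallel rollouts of length T = 50 with
# 4-dimensional actions becomes batch-major.
#
#     episode = {'u': np.zeros((50, 16, 4))}    # (T, rollouts, dim_u)
#     batch = convert_episode_to_batch_major(episode)
#     assert batch['u'].shape == (16, 50, 4)    # (rollouts, T, dim_u)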


def transitions_in_episode_batch(episode_batch):
    """Number of transitions in a given episode batch.
    """
    # episode_batch['u'] is batch-major: (num_rollouts, T, dim_u)
    shape = episode_batch['u'].shape
    return shape[0] * shape[1]


def reshape_for_broadcasting(source, target):
    """Reshapes a tensor (source) to have the dtype of the target and a shape
    of [1, ..., 1, -1] so that it can be broadcast against the target.
    """
    dim = len(target.get_shape())
    shape = ([1] * (dim - 1)) + [-1]
    return tf.reshape(tf.cast(source, target.dtype), shape)
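

# Illustrative shape example for `reshape_for_broadcasting` (shapes chosen for
# the example only): broadcasting a per-dimension mean over a batch of
# observations.
#
#     obs = tf.placeholder(tf.float32, shape=(None, 10))       # (batch, dim_obs)
#     mean = np.zeros(10, dtype=np.float64)                    # (dim_obs,)
#     centered = obs - reshape_for_broadcasting(mean, obs)     # mean -> (1, 10)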