-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgoal_conditioned_ac.py
117 lines (97 loc) · 4.33 KB
/
goal_conditioned_ac.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import tensorflow as tf
from tf_agents.agents.ddpg import actor_network
from tf_agents.agents.ddpg import critic_network
from tf_agents.networks import utils
def set_goal(traj, goal):
    """Return a copy of `traj` with the 'goal' observation entry replaced.

    Args:
        traj: a trajectory namedtuple whose `observation` is a dict containing
            both an 'observation' and a 'goal' entry.
        goal: replacement goal; must have the same nested structure as the
            trajectory's 'observation' entry.

    Returns:
        A new trajectory identical to `traj` except for the swapped-in goal.
    """
    for required_key in ('observation', 'goal'):
        assert required_key in traj.observation.keys()
    obs = traj.observation['observation']
    # The goal must mirror the observation's nested structure so the two can
    # later be concatenated feature-wise.
    tf.nest.assert_same_structure(obs, goal)
    return traj._replace(observation={'observation': obs, 'goal': goal})
def merge_obs_goal(observations):
    """Flatten a {'observation', 'goal'} dict into one concatenated tensor.

    Args:
        observations: dict with rank-2 tensors under 'observation' and 'goal';
            both entries must share the same (batch, features) shape.

    Returns:
        A rank-2 tensor of shape (batch, 2 * features): observation features
        followed by goal features along the last axis.
    """
    obs, goal = observations['observation'], observations['goal']
    assert obs.shape == goal.shape
    assert len(obs.shape) == 2
    merged = tf.concat([obs, goal], axis=-1)
    # Sanity-check the concatenation: batch size unchanged, feature dims added.
    assert merged.shape[0] == obs.shape[0]
    assert merged.shape[1] == obs.shape[1] + goal.shape[1]
    return merged
class GoalConditionedActorNetwork(actor_network.ActorNetwork):
    """Actor network that accepts dict observations of {'observation', 'goal'}.

    On every call the observation and goal tensors are concatenated (via
    `merge_obs_goal`) into a single flat tensor before being fed to the base
    `ActorNetwork`.
    """

    def __init__(self,
                 input_tensor_spec,
                 output_tensor_spec,
                 **kwargs):
        """Creates the network.

        Args:
            input_tensor_spec: spec for the dict observation; stored on the
                instance after the parent is initialized.
            output_tensor_spec: spec of the actions emitted by the network.
            **kwargs: forwarded to `ActorNetwork.__init__`.
        """
        # NOTE(review): the parent is deliberately initialized with a None
        # observation spec; the real (dict) spec is restored below. This
        # presumably sidesteps the parent's flat-spec validation — confirm
        # against tf_agents' ActorNetwork before changing.
        modified_tensor_spec = None
        super(GoalConditionedActorNetwork, self).__init__(
            modified_tensor_spec, output_tensor_spec,
            fc_layer_params=(256, 256),  # hidden layer sizes are hard-coded here
            **kwargs)
        # Overwrite the parent's stored spec with the original dict spec.
        self._input_tensor_spec = input_tensor_spec

    def call(self, observations, step_type=(), network_state=()):
        """Concatenates observation and goal, then delegates to the parent."""
        modified_observations = merge_obs_goal(observations)
        return super(GoalConditionedActorNetwork, self).call(
            modified_observations, step_type=step_type, network_state=network_state)
class GoalConditionedCriticNetwork(critic_network.CriticNetwork):
    """Critic network that accepts dict observations of {'observation', 'goal'}.

    The observation and goal tensors are concatenated (via `merge_obs_goal`)
    before being fed through the usual observation/action/joint MLP stacks.
    Optionally emits a vector of `output_dim` values per (obs, action) pair
    instead of a single scalar.
    """

    def __init__(self,
                 input_tensor_spec,
                 observation_conv_layer_params=None,
                 observation_fc_layer_params=(256,),
                 action_fc_layer_params=None,
                 joint_fc_layer_params=(256,),
                 activation_fn=tf.nn.relu,
                 name='CriticNetwork',
                 output_dim=None):
        """Creates the network.

        Args:
            input_tensor_spec: (observation_spec, action_spec) tuple; the
                observation spec is expected to be the dict form.
            observation_conv_layer_params: conv layer params for the
                observation encoder, or None for no conv layers.
            observation_fc_layer_params: fc layer sizes for the observation
                encoder.
            action_fc_layer_params: fc layer sizes for the action encoder,
                or None.
            joint_fc_layer_params: fc layer sizes applied after observation
                and action encodings are combined.
            activation_fn: activation used throughout the MLP stacks.
            name: name of the network.
            output_dim: if not None, the critic outputs this many values per
                input (reshaped in `call`); otherwise a single scalar.
        """
        self._output_dim = output_dim
        (_, action_spec) = input_tensor_spec
        # NOTE(review): as in the actor, the stored spec is replaced by a
        # None observation spec for the base initializer and the real dict
        # spec is restored afterwards — presumably to bypass flat-spec
        # validation; confirm before changing.
        modified_obs_spec = None
        modified_tensor_spec = (modified_obs_spec, action_spec)
        # This intentionally calls Network.__init__ (the *grandparent*) by
        # naming critic_network.CriticNetwork in super(), skipping
        # CriticNetwork.__init__ entirely; the layer stacks it would have
        # built are re-created by hand below.
        super(critic_network.CriticNetwork, self).__init__(
            input_tensor_spec=modified_tensor_spec,
            state_spec=(),
            name=name)
        self._input_tensor_spec = input_tensor_spec
        flat_action_spec = tf.nest.flatten(action_spec)
        if len(flat_action_spec) > 1:
            raise ValueError('Only a single action is supported by this network')
        self._single_action_spec = flat_action_spec[0]
        # Encoder for the (already merged) observation+goal tensor.
        self._observation_layers = utils.mlp_layers(
            observation_conv_layer_params,
            observation_fc_layer_params,
            activation_fn=activation_fn,
            kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
                scale=1. / 3., mode='fan_in', distribution='uniform'),
            name='observation_encoding')
        # Encoder for the action input (no conv layers).
        self._action_layers = utils.mlp_layers(
            None,
            action_fc_layer_params,
            activation_fn=activation_fn,
            kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
                scale=1. / 3., mode='fan_in', distribution='uniform'),
            name='action_encoding')
        # Joint trunk applied after the two encodings are combined.
        self._joint_layers = utils.mlp_layers(
            None,
            joint_fc_layer_params,
            activation_fn=activation_fn,
            kernel_initializer=tf.compat.v1.keras.initializers.VarianceScaling(
                scale=1. / 3., mode='fan_in', distribution='uniform'),
            name='joint_mlp')
        # Final linear head: output_dim values if requested, else a scalar.
        # Small uniform init keeps initial Q-estimates near zero.
        self._joint_layers.append(
            tf.keras.layers.Dense(
                self._output_dim if self._output_dim is not None else 1,
                activation=None,
                kernel_initializer=tf.compat.v1.keras.initializers.RandomUniform(
                    minval=-0.003, maxval=0.003),
                name='value'))

    def call(self, inputs, step_type=(), network_state=()):
        """Merges obs+goal, runs the parent critic, reshapes multi-dim output."""
        observations, actions = inputs
        modified_observations = merge_obs_goal(observations)
        modified_inputs = (modified_observations, actions)
        output = super(GoalConditionedCriticNetwork, self).call(
            modified_inputs, step_type=step_type, network_state=network_state)
        (predictions, network_state) = output
        if self._output_dim is not None:
            # Parent flattens the head output; restore (batch, output_dim).
            predictions = tf.reshape(predictions, [-1, self._output_dim])
        return predictions, network_state