Skip to content

Commit

Permalink
Adapt algo to 910B
Browse files Browse the repository at this point in the history
  • Loading branch information
MashiroChen committed Dec 11, 2023
1 parent f11d379 commit 63b4150
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 31 deletions.
32 changes: 15 additions & 17 deletions mindspore_rl/algorithm/maddpg/maddpg_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,29 +187,27 @@ def train_one_episode(self):
duration += 1
training_reward += rew_n.sum()

# ----------------------------------------- learner -------------------------------------------
if self.train_step % 100 == 0:
(
# ----------------------------------------- learner -------------------------------------------
if self.train_step % 100 == 0:
(
obs_n_batch,
act_n_batch,
rew_n_batch,
obs_next_n_batch,
done_n_batch,
) = self.msrl.replay_buffer_sample()
agent_id = 0
while agent_id < self.num_agent:
loss += self._learn(
agent_id,
obs_n_batch,
act_n_batch,
rew_n_batch,
obs_next_n_batch,
done_n_batch,
) = self.msrl.replay_buffer_sample()
agent_id = 0
while agent_id < self.num_agent:
loss += self._learn(
agent_id,
obs_n_batch,
act_n_batch,
rew_n_batch,
obs_next_n_batch,
done_n_batch,
)
agent_id += 1
)
agent_id += 1

if dones:
break
return loss, training_reward, duration

def trainable_variables(self):
Expand Down
24 changes: 11 additions & 13 deletions mindspore_rl/algorithm/mappo/mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ class CollectPolicy(nn.Cell):
def __init__(self, actor_net):
super().__init__()
self.categorical_dist = msd.Categorical()
+ self.multinomial = P.Multinomial().add_prim_attr("primitive_target", "CPU")
self.actor_net = actor_net
self.expand_dims = P.ExpandDims()
self.exp = P.Exp()
Expand All @@ -248,7 +249,7 @@ def construct(self, inputs_data):
self.log(self.reduce_sum(categorical_x, -1))
norm_action_prob = self.exp(norm_log_categorical_x)

- actions = self.categorical_dist.sample((), norm_action_prob)
+ actions = self.multinomial(norm_action_prob, 1).squeeze(-1)
log_prob = self.categorical_dist.log_prob(
actions, norm_action_prob)

Expand Down Expand Up @@ -510,19 +511,16 @@ def reshape_tensor_3d(tensor):
value = reshape_tensor_2d(value[1:])
norm_advantage = reshape_tensor_2d(norm_advantage)
discounted_r = reshape_tensor_2d(discounted_r[1:])

L, N = 10, 320
indices = ops.Randperm(N)(Tensor([N], ms.int32))
global_obs = F.gather(_reshape(global_obs, N, L), indices, 0)
local_obs = F.gather(_reshape(local_obs, N, L), indices, 0)
hn_actor = F.gather(hn_actor, indices, 0)
hn_critic = F.gather(hn_critic, indices, 0)
actions = F.gather(_reshape(actions, N, L), indices, 0)
value = F.gather(_reshape(value, N, L), indices, 0)
discounted_r = F.gather(_reshape(discounted_r, N, L), indices, 0)
mask = F.gather(_reshape(mask, N, L), indices, 0)
log_prob = F.gather(_reshape(log_prob, N, L), indices, 0)
norm_advantage = F.gather(_reshape(norm_advantage, N, L), indices, 0)

global_obs = _reshape(global_obs, N, L)
local_obs = _reshape(local_obs, N, L)
actions = _reshape(actions, N, L)
value = _reshape(value, N, L)
discounted_r = _reshape(discounted_r, N, L)
mask = _reshape(mask, N, L)
log_prob = _reshape(log_prob, N, L)
norm_advantage = _reshape(norm_advantage, N, L)

global_obs = _cast1(global_obs)
local_obs = _cast1(local_obs)
Expand Down
6 changes: 5 additions & 1 deletion mindspore_rl/utils/soft_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,12 @@ def _update(self, factor, behavior_param, target_param):
return target_param

     def construct(self):
-        if not self.mod(self.steps, self.update_interval):
+        if self.update_interval == 1:
             updater = F.partial(self._update, self.factor)
             self.hyper_map(updater, self.behavior_params, self.target_params)
+        else:
+            if not self.mod(self.steps, self.update_interval):
+                updater = F.partial(self._update, self.factor)
+                self.hyper_map(updater, self.behavior_params, self.target_params)
         self.steps += 1
         return self.steps

0 comments on commit 63b4150

Please sign in to comment.