
Commit

Adapt algo to 910B
MashiroChen committed Dec 11, 2023
1 parent f11d379 commit 9525fc7
Showing 5 changed files with 36 additions and 33 deletions.
5 changes: 4 additions & 1 deletion example/ppo/train.py
@@ -66,6 +66,9 @@
 parser.add_argument(
     "--worker_num", type=int, default=2, help="Worker num (Default: 2)."
 )
+parser.add_argument(
+    "--graph_op_run", type=int, default=1, help="Run kernel by kernel (Default: 1)."
+)
 options, _ = parser.parse_known_args()


@@ -75,7 +78,7 @@ def train(episode=options.episode):
     context.set_context(device_target=options.device_target)
     if context.get_context("device_target") in ["CPU"]:
         context.set_context(enable_graph_kernel=True)
-    if context.get_context("device_target") in ["Ascend"]:
+    if context.get_context("device_target") in ["Ascend"] and options.graph_op_run:
         os.environ["GRAPH_OP_RUN"] = "1"

     compute_type = (
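The new --graph_op_run switch gates the GRAPH_OP_RUN environment variable, which asks MindSpore to launch Ascend kernels one by one ("run kernel by kernel", as the help text puts it) instead of sinking the whole graph to the device. A minimal, self-contained sketch of the same pattern, using only what the diff shows (the flag and variable names are taken from the diff; everything else is illustrative):

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument("--device_target", type=str, default="Ascend")
parser.add_argument(
    "--graph_op_run", type=int, default=1, help="Run kernel by kernel (Default: 1)."
)
options, _ = parser.parse_known_args()

# Mirror of the change in train.py: set the variable before training starts so
# the Ascend backend picks kernel-by-kernel execution.
if options.device_target == "Ascend" and options.graph_op_run:
    os.environ["GRAPH_OP_RUN"] = "1"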
32 changes: 15 additions & 17 deletions mindspore_rl/algorithm/maddpg/maddpg_trainer.py
@@ -187,29 +187,27 @@ def train_one_episode(self):
duration += 1
training_reward += rew_n.sum()

# ----------------------------------------- learner -------------------------------------------
if self.train_step % 100 == 0:
    (
        obs_n_batch,
        act_n_batch,
        rew_n_batch,
        obs_next_n_batch,
        done_n_batch,
    ) = self.msrl.replay_buffer_sample()
    agent_id = 0
    while agent_id < self.num_agent:
        loss += self._learn(
            agent_id,
            obs_n_batch,
            act_n_batch,
            rew_n_batch,
            obs_next_n_batch,
            done_n_batch,
        )
        agent_id += 1

if dones:
break
return loss, training_reward, duration

def trainable_variables(self):
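For orientation, the learner section above keeps the usual MADDPG pattern: whenever train_step is a multiple of 100, one batch is drawn from the replay buffer and every agent is trained on that same batch. A toy, framework-free sketch of just that control flow (all names are stand-ins, not the repo's API):

def run_learner_step(train_step, num_agent, replay_buffer_sample, learn):
    # Toy sketch: replay_buffer_sample and learn are caller-provided stubs.
    loss = 0.0
    if train_step % 100 == 0:
        batch = replay_buffer_sample()  # one sample shared by all agents
        agent_id = 0
        while agent_id < num_agent:
            loss += learn(agent_id, batch)
            agent_id += 1
    return loss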
24 changes: 11 additions & 13 deletions mindspore_rl/algorithm/mappo/mappo.py
@@ -230,6 +230,7 @@ class CollectPolicy(nn.Cell):
     def __init__(self, actor_net):
         super().__init__()
         self.categorical_dist = msd.Categorical()
+        self.multinomial = P.Multinomial().add_prim_attr("primitive_target", "CPU")
         self.actor_net = actor_net
         self.expand_dims = P.ExpandDims()
         self.exp = P.Exp()
@@ -248,7 +249,7 @@ def construct(self, inputs_data):
             self.log(self.reduce_sum(categorical_x, -1))
         norm_action_prob = self.exp(norm_log_categorical_x)

-        actions = self.categorical_dist.sample((), norm_action_prob)
+        actions = self.multinomial(norm_action_prob, 1).squeeze(-1)
         log_prob = self.categorical_dist.log_prob(
             actions, norm_action_prob)

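The sampling change swaps msd.Categorical.sample for an ops.Multinomial primitive pinned to the CPU through add_prim_attr("primitive_target", "CPU"), while log_prob still goes through the Categorical distribution; presumably the sampling op itself is better run on the host for the 910B target. A rough standalone sketch of the two paths with made-up probabilities (not the repo's code):

import mindspore as ms
import mindspore.nn.probability.distribution as msd
from mindspore import Tensor, ops as P

# Made-up, already normalized action probabilities of shape (batch, num_actions).
norm_action_prob = Tensor([[0.1, 0.2, 0.7], [0.5, 0.3, 0.2]], ms.float32)

categorical_dist = msd.Categorical()

# Old path: sample through the Categorical distribution wrapper.
old_actions = categorical_dist.sample((), norm_action_prob)

# New path: the Multinomial primitive, forced onto the CPU, draws one sample per
# row; squeeze(-1) drops the sample dimension so the shape matches the old path.
multinomial = P.Multinomial().add_prim_attr("primitive_target", "CPU")
new_actions = multinomial(norm_action_prob, 1).squeeze(-1)

# The log-probability of the chosen actions is still computed the old way.
log_prob = categorical_dist.log_prob(new_actions, norm_action_prob)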
@@ -510,19 +511,16 @@ def reshape_tensor_3d(tensor):
         value = reshape_tensor_2d(value[1:])
         norm_advantage = reshape_tensor_2d(norm_advantage)
         discounted_r = reshape_tensor_2d(discounted_r[1:])

         L, N = 10, 320
-        indices = ops.Randperm(N)(Tensor([N], ms.int32))
-        global_obs = F.gather(_reshape(global_obs, N, L), indices, 0)
-        local_obs = F.gather(_reshape(local_obs, N, L), indices, 0)
-        hn_actor = F.gather(hn_actor, indices, 0)
-        hn_critic = F.gather(hn_critic, indices, 0)
-        actions = F.gather(_reshape(actions, N, L), indices, 0)
-        value = F.gather(_reshape(value, N, L), indices, 0)
-        discounted_r = F.gather(_reshape(discounted_r, N, L), indices, 0)
-        mask = F.gather(_reshape(mask, N, L), indices, 0)
-        log_prob = F.gather(_reshape(log_prob, N, L), indices, 0)
-        norm_advantage = F.gather(_reshape(norm_advantage, N, L), indices, 0)
+        global_obs = _reshape(global_obs, N, L)
+        local_obs = _reshape(local_obs, N, L)
+        actions = _reshape(actions, N, L)
+        value = _reshape(value, N, L)
+        discounted_r = _reshape(discounted_r, N, L)
+        mask = _reshape(mask, N, L)
+        log_prob = _reshape(log_prob, N, L)
+        norm_advantage = _reshape(norm_advantage, N, L)

         global_obs = _cast1(global_obs)
         local_obs = _cast1(local_obs)
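The second mappo.py change drops the Randperm/gather shuffle of the flattened rollout and keeps only the reshape, so the minibatch rows stay in the order they were collected; I read this as avoiding Randperm and Gather on the 910B path, though the commit message does not say so. A standalone sketch of what was removed versus what remains, with a made-up tensor and a stand-in _reshape helper (the real helper lives elsewhere in mappo.py, and Randperm support depends on the backend):

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

L, N = 10, 320  # constants taken from the diff; here N is just the flattened row count

def _reshape(tensor, n, l):
    # Stand-in for the repo's helper: flatten the leading dimensions into n rows.
    return tensor.reshape((n,) + tensor.shape[2:])

value = Tensor(np.random.randn(L, N // L, 1).astype(np.float32))

# Old path: flatten, then visit the rows in a random order.
indices = ops.Randperm(max_length=N)(Tensor([N], ms.int32))
shuffled = ops.gather(_reshape(value, N, L), indices, 0)

# New path: flatten only; the row order is left as collected.
ordered = _reshape(value, N, L)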
2 changes: 1 addition & 1 deletion mindspore_rl/algorithm/ppo/config.py
@@ -70,7 +70,7 @@
     "policy_and_network": {"type": PPOPolicy, "params": policy_params},
     "collect_environment": {
         "number": 30,
-        "num_parallel": 30,
+        "num_parallel": 5,
         "type": GymEnvironment,
         "wrappers": [PyFuncWrapper, SyncParallelWrapper],
         "params": collect_env_params,
6 changes: 5 additions & 1 deletion mindspore_rl/utils/soft_update.py
@@ -79,8 +79,12 @@ def _update(self, factor, behavior_param, target_param):
         return target_param

     def construct(self):
-        if not self.mod(self.steps, self.update_interval):
+        if self.update_interval == 1:
             updater = F.partial(self._update, self.factor)
             self.hyper_map(updater, self.behavior_params, self.target_params)
+        else:
+            if not self.mod(self.steps, self.update_interval):
+                updater = F.partial(self._update, self.factor)
+                self.hyper_map(updater, self.behavior_params, self.target_params)
         self.steps += 1
         return self.steps

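The SoftUpdate change adds a fast path: when update_interval is 1 the target parameters are refreshed on every call, without going through self.mod, and only otherwise is the step counter checked. A plain-Python sketch of that behaviour, assuming the usual Polyak rule target = factor * behavior + (1 - factor) * target (the real cell applies it to Parameters through hyper_map):

def soft_update_step(steps, update_interval, factor, behavior_params, target_params):
    # Plain-Python mirror of SoftUpdate.construct; parameters are lists of floats here.
    def _update(behavior, target):
        # Polyak averaging: move the target a small step toward the behavior weights.
        return factor * behavior + (1.0 - factor) * target

    if update_interval == 1:
        # Fast path from this commit: update on every call, skipping the modulo.
        target_params[:] = [_update(b, t) for b, t in zip(behavior_params, target_params)]
    elif steps % update_interval == 0:
        target_params[:] = [_update(b, t) for b, t in zip(behavior_params, target_params)]
    return steps + 1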