import gym, ray
from gym import spaces
import numpy as np
from scipy.spatial import distance
import MultiAgentEnv as ma_env
from policy import PolicyNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from ray import tune
import ray.rllib.agents.ppo as ppo
import os
from ray.tune.logger import pretty_print
from ray.tune.logger import Logger
from typing import Dict
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.schedules import ConstantSchedule, LinearSchedule, ExponentialSchedule, PiecewiseSchedule
from datetime import datetime
LOG_FILE = "logs/discrete_reward_strict{}.txt".format(datetime.now().strftime("%d_%m_%H_%M"))
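# The callback below appends to LOG_FILE, so make sure the log directory exists
# (an added safeguard; assumes the logs/ folder may not already be present).
os.makedirs(os.path.dirname(LOG_FILE), exist_ok=True)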
class MyCallbacks(DefaultCallbacks):
    """Log every agent's trajectory and energy to LOG_FILE at the end of each episode."""

    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                       policies: Dict[str, Policy], episode: MultiAgentEpisode,
                       env_index: int, **kwargs):
        trajectories = base_env.get_unwrapped()[0].trajectory
        energies = base_env.get_unwrapped()[0].energies
        with open(LOG_FILE, "a") as outFile:
            for idx, trajectory in enumerate(trajectories):
                outFile.write("episode_id: {} \t env-idx: {} \t pos:{} \t energy:{}\n".format(
                    episode.episode_id, env_index, str(list(trajectory.flatten())), energies[idx]))
# create and register the custom NN model for each atom type (only type A is used here)
model_A = PolicyNetwork
ModelCatalog.register_custom_model("modelA", model_A)
# define the action space and observation space
# the action space is the step the policy takes, in angstroms
# the observation space is the per-atom observation vector, which includes the atom's coordinates
act_space = spaces.Box(low=-0.05, high=0.05, shape=(3,))
obs_space = spaces.Box(low=-1000, high=1000, shape=(128 + 5 + 3,))
def gen_policy(atom):
    # each atom type uses its own registered custom model; None keeps the default PPO policy class
    model = "model{}".format(atom)
    config = {"model": {"custom_model": model}}
    return (None, obs_space, act_space, config)
policies = {"policy_A": gen_policy("A")}
policy_ids = list(policies.keys())

def policy_mapping_fn(agent_id, episode, **kwargs):
    # every agent maps to the single shared policy
    return "policy_A"
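# A minimal sketch (assumption, not active code): if separate per-atom-type
# policies such as "policy_N", "policy_O", "policy_H" (hinted at in the
# commented-out policies_to_train list below) were enabled, the mapping could
# key off a hypothetical agent id like "C_0" or "H_3":
#
# def policy_mapping_fn(agent_id, episode, **kwargs):
#     return "policy_{}".format(agent_id.split("_")[0])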
def env_creator(env_config):
    return ma_env.MA_env(env_config)  # return an env instance

register_env("MA_env", env_creator)
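# register_env makes the string id "MA_env" resolvable by RLlib; env_creator is
# then called on every rollout worker with the dict supplied as
# config["env_config"] below.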
config = ppo.DEFAULT_CONFIG.copy()
config["multiagent"] = {
    "policy_mapping_fn": policy_mapping_fn,
    "policies": policies,
    "policies_to_train": ["policy_A"],  # , "policy_N", "policy_O", "policy_H"
    "count_steps_by": "env_steps",
}
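# Register the custom callback so the per-episode trajectory/energy logging
# defined above actually runs (assumed intent: MyCallbacks is otherwise unused).
config["callbacks"] = MyCallbacks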
config["entropy_coeff"] = 0.0001
config["kl_coeff"] = 1.0
config["kl_target"] = 0.01
config["gamma"] = 0.90
config["lambda"] = 1.00 #1
config["clip_param"] = 0.3
config["vf_clip_param"] = 10 #10 or 40
config["log_level"] = "INFO"
config["framework"] = "torch"
config["num_gpus"] = 1
# config["num_gpus_per_worker"] = 0.
# config["env_config"] = {"atoms":["C", "N", "O", "H"]}
config["env_config"] = {"atoms":["C", "H"]}
config["rollout_fragment_length"] = 200 # train_batch_size / rollout_fragment_length = num_fragments
config["sgd_minibatch_size"] = 512
config["train_batch_size"] = 2048
config["num_workers"] = 2
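# With rollout_fragment_length=200 and num_workers=2, each sampling round yields
# about 400 env steps, so roughly five to six rounds are collected to fill a
# train_batch_size of 2048 (RLlib rounds up to whole fragments).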
# config["num_envs_per_worker"] = 1
# config["remote_worker_envs"] = False
config["ignore_worker_failures"] = False
config["horizon"] = 20
config["soft_horizon"] = False
config["batch_mode"] = "truncate_episodes"
config["vf_share_layers"] = True
config["lr"] = 5e-05
print(pretty_print(config))
ray.init()
agent = ppo.PPOTrainer(config, env="MA_env")
## Use this to restart training
# model_restore = "/home/rohit.modee/ray_results/PPO_MA_env_2022-05-25_12-49-2717s7j3br/"
# agent.restore(model_restore + "checkpoint_000031/checkpoint-31")
n_iter = 2800
for n in range(n_iter):
    result = agent.train()
    print(pretty_print(result))
    # save a checkpoint every 5 training iterations
    if n % 5 == 0:
        checkpoint = agent.save()
        print("checkpoint saved at", checkpoint)

# save a final checkpoint once training finishes
checkpoint = agent.save()
print("checkpoint saved at", checkpoint)
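# A minimal sketch (assumption, not part of the original script): the final
# checkpoint could later be reloaded for inference with the same config, e.g.
#
# eval_agent = ppo.PPOTrainer(config, env="MA_env")
# eval_agent.restore(checkpoint)
# env = env_creator(config["env_config"])
# obs = env.reset()
# actions = {agent_id: eval_agent.compute_action(agent_obs, policy_id="policy_A")
#            for agent_id, agent_obs in obs.items()}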