symbolic_finetuning_unguided.py
import pickle

import gym
import numpy as np
import retro
from gplearn.fitness import make_fitness
from gplearn.genetic import SymbolicRegressor

from wrappers.circuscharlie import SimpleCircusCharlieWrapper # noqa
from wrappers.pong import SimplePongWrapper # noqa
from wrappers.seaquest import SimpleSeaquestWrapper
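
# Evolves a symbolic control policy with genetic programming: each candidate
# program maps an observation to a raw scalar, which is squashed into an
# action and scored by the total episode reward it collects.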


def to_discrete(raw_action, size=6):
    """Map a raw program output to a discrete action index in [0, size)."""
    # Logistic squashing of the unbounded program output onto (0, size].
    scaled_action = size / (1 + np.exp(raw_action))
    # Very negative raw_action saturates at exactly `size` in floating point,
    # which would be out of range, so clamp it to size - 1.
    if scaled_action == size:
        return int(scaled_action) - 1
    else:
        return int(scaled_action)
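
# Worked example (illustrative values, not from the original source):
#   to_discrete(0.0, size=6)   -> 6 / (1 + e^0) = 3.0        -> action 3
#   to_discrete(50.0, size=6)  -> ~0.0                       -> action 0
#   to_discrete(-50.0, size=6) -> saturates at 6.0, clamped  -> action 5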


def to_continuous(raw_action, low, high):
    """Map a raw program output to a continuous action in [low, high]."""
    scaled_action = (high - low) / (1 + np.exp(raw_action)) + low
    return np.float32(scaled_action)
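
# Sketch: with low=-1.0, high=1.0, raw_action=0.0 maps to 2 / 2 - 1 = 0.0;
# raw_action -> -inf approaches `high`, raw_action -> +inf approaches `low`.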


def get_rewards(program, env_, render=False, assert_torch=False):
    """Return the mean episode reward of `program` over four rollouts."""
    env = env_
    rewards = []
    for _ in range(4):  # average over several episodes to reduce variance
        obs = env.reset()
        done = False
        episode_reward = 0
        while not done:
            # Evaluate the symbolic program on the current observation.
            raw_action = program.execute(np.expand_dims(obs, axis=0))
            if isinstance(env.action_space, gym.spaces.discrete.Discrete):
                action = to_discrete(raw_action, env.action_space.n)
            elif isinstance(env.action_space, gym.spaces.box.Box):
                action = to_continuous(raw_action, env.action_space.low, env.action_space.high)
            else:
                raise NotImplementedError("Only Discrete and Box action spaces are currently supported!")
            if assert_torch:
                # Cross-check the torch execution path against the numpy one.
                raw_action_torch = program.execute_torch(np.expand_dims(obs, axis=0))
                action_torch = np.tanh(raw_action_torch.detach().numpy()).astype(np.float32)
                assert np.isclose(action_torch, action, atol=1e-4)
            obs, reward, done, info = env.step(action)
            episode_reward += reward
            if render:
                env.render()
        rewards.append(episode_reward)
    # env.close()
    return np.mean(rewards)
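
# Direct-evaluation sketch (hypothetical usage, mirroring the commented-out
# lines in main() below):
#   mean_reward = get_rewards(est_gp._program, env, render=True)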


# Wrap the rollout reward as a gplearn fitness; `skip_checks` suggests a
# gplearn fork, since upstream make_fitness takes no such argument.
rewards = make_fitness(get_rewards, greater_is_better=True, skip_checks=True)


def main():
    # Alternative environments, kept for reference:
    # env = gym.make("CartPole-v1")
    # env = gym.make("MountainCarContinuous-v0")
    # env = SimplePongWrapper(gym.make("PongNoFrameskip-v4"))
    # env = SimpleCircusCharlieWrapper(retro.make("CircusCharlie-Nes"))
    env = SimpleSeaquestWrapper(retro.make("Seaquest-Atari2600"))
    est_gp = SymbolicRegressor(
        population_size=16,
        generations=100,
        stopping_criteria=1000,
        metric=rewards,
        verbose=1,
        n_jobs=64,
        p_crossover=0.9,
        p_constants_sgd=0,  # fork-specific; 0 = no SGD on constants ("unguided")
        # parsimony_coefficient=0.0001,
        init_depth=(2, 8),
    )
    try:
        # Fork-specific API: fitness comes from environment rollouts rather
        # than (X, y) pairs; Ctrl-C falls through to the pickle dump below.
        est_gp.fit(env=env)
    except KeyboardInterrupt:
        pass
    # print(est_gp._program)
    # final_rewards = get_rewards(est_gp._program, env, render=False, assert_torch=False)
    # print(final_rewards)
    with open("seaquest.pkl", "wb") as f:
        pickle.dump(est_gp, f)


if __name__ == "__main__":
    main()
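
# Reloading sketch (hypothetical follow-up, not part of this script):
#   with open("seaquest.pkl", "rb") as f:
#       est_gp = pickle.load(f)
#   print(est_gp._program)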