value_iteration.py
import logging

import numpy as np

from rl_agents.agents.common.abstract import AbstractAgent

logger = logging.getLogger(__name__)


class ValueIterationAgent(AbstractAgent):
    """An agent that plans in a finite MDP by value iteration."""

    def __init__(self, env, config=None):
        super(ValueIterationAgent, self).__init__(config)
        self.finite_mdp = self.is_finite_mdp(env)
        if self.finite_mdp:
            self.mdp = env.mdp
        else:
            # Fall back to environments that can be converted to a finite MDP.
            try:
                self.mdp = env.unwrapped.to_finite_mdp()
            except AttributeError:
                raise TypeError("Environment must be of type finite_mdp.envs.finite_mdp.FiniteMDPEnv or handle a "
                                "conversion method called 'to_finite_mdp' to such a type.")
        self.env = env
        self.state_action_value = self.get_state_action_value()

    @classmethod
    def default_config(cls):
        return dict(gamma=1.0,
                    iterations=100)

    def act(self, state):
        # If the environment is not a finite mdp, it must be converted to one and the state must be recovered.
        if not self.finite_mdp:
            self.mdp = self.env.unwrapped.to_finite_mdp()
            state = self.mdp.state
            self.state_action_value = self.get_state_action_value()
        return np.argmax(self.state_action_value[state, :])

    def get_state_value(self):
        # Iterate the Bellman optimality operator on the state value function V.
        return self.fixed_point_iteration(
            lambda v: ValueIterationAgent.best_action_value(self.bellman_expectation(v)),
            np.zeros((self.mdp.transition.shape[0],)))

    def get_state_action_value(self):
        # Iterate the Bellman optimality operator on the state-action value function Q.
        return self.fixed_point_iteration(
            lambda q: self.bellman_expectation(ValueIterationAgent.best_action_value(q)),
            np.zeros((self.mdp.transition.shape[0:2])))

    @staticmethod
    def best_action_value(action_values):
        # Greedy reduction over the action axis: V(s) = max_a Q(s, a).
        return action_values.max(axis=-1)

    def bellman_expectation(self, value):
        # Expected next value, depending on how the transition model is encoded.
        if self.mdp.mode == "deterministic":
            # transition[s, a] is the unique next state.
            next_v = value[self.mdp.transition]
        elif self.mdp.mode == "stochastic":
            # transition[s, a, s'] is a full probability distribution over next states.
            next_v = (self.mdp.transition * value.reshape((1, 1, value.size))).sum(axis=-1)
        elif self.mdp.mode == "sparse":
            # P(s,a,B) * v[B], where mdp.next lists the candidate next states B.
            next_values = np.take(value, self.mdp.next)
            next_v = (self.mdp.transition * next_values).sum(axis=-1)
        else:
            raise ValueError("Unknown mode")
        next_v[self.mdp.terminal] = 0
        return self.mdp.reward + self.config["gamma"] * next_v
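
    # A minimal illustration of the stochastic backup (assumed shapes, not part of
    # the original file): transition has shape (S, A, S) and value has shape (S,),
    # so the backup computes sum_s' P(s'|s, a) * v(s') for every (s, a) pair, e.g.
    #   P = np.array([[[0.9, 0.1], [0.2, 0.8]],
    #                 [[1.0, 0.0], [0.5, 0.5]]])   # 2 states, 2 actions
    #   v = np.array([0.0, 1.0])
    #   (P * v.reshape((1, 1, 2))).sum(axis=-1)    # -> [[0.1, 0.8], [0.0, 0.5]]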

    def fixed_point_iteration(self, operator, initial):
        # Repeatedly apply the operator until the value stops changing or the
        # iteration budget is exhausted.
        value = initial
        for iteration in range(self.config["iterations"]):
            logger.debug("Value Iteration: {}/{}".format(iteration, self.config["iterations"]))
            next_value = operator(value)
            if np.allclose(value, next_value):
                break
            value = next_value
        return value
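
    # A minimal illustration (assumed, not in the original file): with a
    # contracting operator such as T(v) = r + gamma * v for gamma < 1,
    #   self.fixed_point_iteration(lambda v: r + gamma * v, np.zeros(n))
    # approaches the fixed point r / (1 - gamma), up to the iteration budget.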

    @staticmethod
    def is_finite_mdp(env):
        # Check whether the (unwrapped) environment is a finite_mdp FiniteMDPEnv,
        # without making the optional finite_mdp package a hard dependency.
        try:
            finite_mdp = __import__("finite_mdp.envs.finite_mdp_env")
            return isinstance(env.unwrapped, finite_mdp.envs.finite_mdp_env.FiniteMDPEnv)
        except (ModuleNotFoundError, TypeError):
            return False

    def plan_trajectory(self, state, horizon=10):
        # Roll out the greedy policy from a state, for at most `horizon` steps.
        action_value = self.get_state_action_value()
        states, actions = [], []
        for _ in range(horizon):
            action = np.argmax(action_value[state])
            states.append(state)
            actions.append(action)
            state = self.mdp.next_state(state, action)
            if self.mdp.terminal[state]:
                states.append(state)
                actions.append(None)
                break
        return states, actions

    def record(self, state, action, reward, next_state, done, info):
        pass

    def reset(self):
        pass

    def seed(self, seed=None):
        pass

    def save(self, filename):
        return False

    def load(self, filename):
        return False
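

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). It assumes a gym-style
# environment whose unwrapped instance either is a finite_mdp FiniteMDPEnv or
# exposes a `to_finite_mdp()` conversion; the environment id below is an
# assumption, not something this file defines.
# ---------------------------------------------------------------------------
# import gym
# import finite_mdp  # noqa: F401  (registers the finite-mdp environments, if installed)
#
# env = gym.make("finite-mdp-v0")
# agent = ValueIterationAgent(env, config={"gamma": 0.9, "iterations": 200})
# state = env.reset()
# done = False
# while not done:
#     action = agent.act(state)
#     state, reward, done, info = env.step(action)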