Add basic training iteration
Shubham Jha committed Jun 11, 2020
1 parent 37adf3b commit 7026378
Showing 14 changed files with 166 additions and 11 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,3 +1,3 @@
 tensorflow==1.11.0
-gym==0.10.8
+gym==0.17.2
 numpy==1.15.4
4 changes: 0 additions & 4 deletions rlkit/__init__.py
@@ -1,4 +0,0 @@
-from .algorithms.random_agent import RandomAgent
-from .algorithms.dqn import DQN
-from .algorithms.policy_gradients import REINFORCE
-from .algorithms.agent import Agent
23 changes: 23 additions & 0 deletions rlkit/__main__.py
@@ -0,0 +1,23 @@
from rlkit.agents import RandomAgent
from rlkit.environments.gym_environment import GymEnvironment
from rlkit.trainers import BasicTrainer

params = {
    "environment_params": {
        "env_name": "SpaceInvaders-v0",
    },
    "agent_params": {},
    "training_params": {
        "run_name": "test_run",
        "train_interval": 10,
        "episodes": 5,
        "steps": 500,
    },
}

env = GymEnvironment(params["environment_params"])
agent = RandomAgent(params["agent_params"], env.get_action_space())
trainer = BasicTrainer(params["training_params"], agent, env)
trainer.train()
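Since this entry point lives in rlkit/__main__.py, the package would be launched as a module, presumably with python -m rlkit from the repository root (standard Python module execution; the commit itself does not document a run command). Note that SpaceInvaders-v0 also needs Gym's Atari extras (pip install gym[atari]) on top of requirements.txt.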
1 change: 1 addition & 0 deletions rlkit/agents/__init__.py
@@ -0,0 +1 @@
from .random_agent import RandomAgent
9 changes: 9 additions & 0 deletions rlkit/agents/random_agent.py
@@ -0,0 +1,9 @@
from rlkit.core.base_agent import BaseAgent


class RandomAgent(BaseAgent):
    def __init__(self, params, action_space):
        super(RandomAgent, self).__init__(params, action_space)

    def get_action(self, state):
        # The state is ignored; sample uniformly from the action space.
        return self.action_space.sample()
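For illustration, a minimal usage sketch of RandomAgent against a raw Gym action space (not part of the commit; the environment name is arbitrary):

import gym

from rlkit.agents import RandomAgent

env = gym.make("CartPole-v0")              # any Gym environment works here
agent = RandomAgent({}, env.action_space)  # the params dict is unused so far
action = agent.get_action(state=None)      # state is ignored; uniform sample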
3 changes: 3 additions & 0 deletions rlkit/core/__init__.py
@@ -0,0 +1,3 @@
from .base_agent import BaseAgent
from .base_environment import BaseEnvironment
from .base_trainer import BaseTrainer
1 change: 1 addition & 0 deletions rlkit/core/base_action.py
@@ -0,0 +1 @@
class BaseAction:
    pass
10 changes: 10 additions & 0 deletions rlkit/core/base_agent.py
@@ -0,0 +1,10 @@
class BaseAgent:
    def __init__(self, params, action_space):
        self.params = params
        self.action_space = action_space

    def train(self):
        pass

    def get_action(self, state):
        pass
16 changes: 16 additions & 0 deletions rlkit/core/base_environment.py
@@ -0,0 +1,16 @@
class BaseEnvironment:
    def __init__(self):
        self.to_render = False
        self.reset()

    def execute_action(self, action):
        pass

    def reset(self):
        pass

    def render(self):
        pass

    def set_render(self, to_render):
        self.to_render = to_render
8 changes: 5 additions & 3 deletions rlkit/core/base_trainer.py
@@ -1,8 +1,10 @@
class BaseTrainer:
def __init__(self):
pass
def __init__(self, params):
self.global_step = 0
self.episodes = params.get("episodes", 10);
self.steps = params.get("steps", 100)

def step(self):
def do_step(self):
pass

def train(self):
Expand Down
45 changes: 45 additions & 0 deletions rlkit/environments/gym_environment.py
@@ -0,0 +1,45 @@
import gym
from rlkit.core import BaseEnvironment


class GymEnvironment(BaseEnvironment):
    def __init__(self, params):
        self.params = params
        self.env_name = params["env_name"]
        self.env = gym.make(self.env_name)
        super(GymEnvironment, self).__init__()

    def execute_action(self, action):
        self.env.step(action)

    def get_action_space(self):
        return self.env.action_space

    def reset(self, reset_values=True):
        if reset_values:
            self.reset_values()
        self.reset_env()

    def reset_values(self):
        self.state = None
        self.reward = None
        self.done = False
        self.info = None

    def reset_env(self):
        # Keep the initial observation so callers can read it from state.
        self.state = self.env.reset()

    def close(self):
        print("closing env")
        return self.env.close()

    def render(self):
        self.env.render()

    def step(self, action):
        self.state, self.reward, self.done, self.info = self.env.step(action)
        return self.state, self.reward, self.done, self.info


if __name__ == "__main__":
    test_env = GymEnvironment({"env_name": "MountainCarContinuous-v0"})
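A slightly fuller smoke test, sketched under the same assumptions (arbitrary environment name, random actions, no rendering):

env = GymEnvironment({"env_name": "CartPole-v0"})
env.reset()
while not env.done:
    state, reward, done, info = env.step(env.get_action_space().sample())
env.close()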
14 changes: 14 additions & 0 deletions rlkit/environments/vizdoom_environment.py
@@ -0,0 +1,14 @@
from vizdoom import DoomGame

from rlkit.core import BaseEnvironment


class VizDoomEnvironment(BaseEnvironment):
    def __init__(self, params):
        super(VizDoomEnvironment, self).__init__()
        self.env_name = params["env_name"]

    def initialize_env(self):
        self.env = DoomGame()
        self.env.load_config("../config/basic.cfg")
        self.env.init()
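Nothing in this commit calls initialize_env() yet; a hypothetical caller would wire it up along these lines (the "basic" name is a placeholder, and the config path resolves relative to the working directory):

env = VizDoomEnvironment({"env_name": "basic"})  # "basic" is a placeholder
env.initialize_env()  # loads ../config/basic.cfg and starts the game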
1 change: 1 addition & 0 deletions rlkit/trainers/__init__.py
@@ -0,0 +1 @@
from .basic_trainer import BasicTrainer
40 changes: 37 additions & 3 deletions rlkit/trainers/basic_trainer.py
@@ -1,3 +1,37 @@
-class BasicTrainer:
-    def __init__(self):
-        pass
+from rlkit.core import BaseTrainer
+
+
+class BasicTrainer(BaseTrainer):
+    def __init__(self, params, agent, environment):
+        self.agent = agent
+        self.environment = environment
+        super(BasicTrainer, self).__init__(params)
+
+        self.train_interval = params["train_interval"]
+        self.run_name = params["run_name"]
+        self.episodes = params["episodes"]
+        self.steps = params["steps"]
+
+    def do_step(self):
+        action = self.agent.get_action(self.environment.state)
+        self.environment.step(action)
+        self.environment.render()  # TODO: find better solution
+
+    def train(self):
+        try:
+            for episode in range(1, self.episodes + 1):
+                step = 0
+                self.environment.reset()
+                while step < self.steps and not self.environment.done:
+                    print("episode: {}, step: {}".format(episode, step))
+                    self.do_step()
+
+                    # Train the agent once every train_interval global steps
+                    if self.global_step > 0 and self.global_step % self.train_interval == 0:
+                        self.agent.train()
+
+                    # Increment the per-episode and global step counts
+                    step += 1
+                    self.global_step += 1
+        finally:
+            self.environment.close()
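With the values from rlkit/__main__.py (episodes=5, steps=500, train_interval=10), a run in which no episode terminates early performs 2,500 environment steps. Because the guard runs before global_step is incremented, it sees the values 0 through 2,499 and calls agent.train() at each nonzero multiple of 10, i.e. 249 times; global_step persists across episodes, while step resets to 0 at the start of each one.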
