-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Shubham Jha
committed
Jun 11, 2020
1 parent
37adf3b
commit 7026378
Showing
14 changed files
with
166 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,3 @@ | ||
tensorflow==1.11.0 | ||
gym==0.10.8 | ||
gym==0.17.2 | ||
numpy==1.15.4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +0,0 @@ | ||
from .algorithms.random_agent import RandomAgent | ||
from .algorithms.dqn import DQN | ||
from .algorithms.policy_gradients import REINFORCE | ||
from .algorithms.agent import Agent | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from rlkit.agents import RandomAgent | ||
from rlkit.environments.gym_environment import GymEnvironment | ||
from rlkit.trainers import BasicTrainer | ||
|
||
# Example run: a RandomAgent on the "SpaceInvaders-v0" gym environment,
# driven by BasicTrainer for a handful of short episodes.
params = {
    "environment_params": {
        "env_name": "SpaceInvaders-v0",
    },
    # RandomAgent takes no settings of its own, so this stays empty.
    "agent_params": {},
    "training_params": {
        "run_name": "test_run",
        "train_interval": 10,
        "episodes": 5,
        "steps": 500,
    },
}

env = GymEnvironment(params["environment_params"])
agent = RandomAgent(params["agent_params"], env.get_action_space())
trainer = BasicTrainer(params["training_params"], agent, env)
trainer.train()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .random_agent import RandomAgent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from rlkit.core.base_agent import BaseAgent | ||
|
||
|
||
class RandomAgent(BaseAgent):
    """Agent that picks every action by sampling its action space."""

    def __init__(self, params, action_space):
        """Forward ``params`` and ``action_space`` to ``BaseAgent``."""
        super(RandomAgent, self).__init__(params, action_space)

    def get_action(self, state):
        """Return ``self.action_space.sample()``; ``state`` is not used."""
        return self.action_space.sample()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .base_agent import BaseAgent | ||
from .base_environment import BaseEnvironment | ||
from .base_trainer import BaseTrainer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
class BaseAction:
    """Placeholder base class for actions; it defines no behaviour yet."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
class BaseAgent:
    """Base class for agents: stores configuration and the action space.

    Subclasses override ``get_action`` (and ``train`` if they learn); the
    defaults here do nothing and return ``None``.
    """

    def __init__(self, params, action_space):
        """Keep ``params`` and ``action_space`` on the instance.

        Args:
            params: Agent configuration, stored as ``self.params``.
            action_space: Object describing the available actions, stored
                as ``self.action_space``.
        """
        self.params = params
        self.action_space = action_space

    def train(self):
        """Run one training update; a no-op in the base class."""
        pass

    def get_action(self, state):
        """Choose an action for ``state``; the base class returns ``None``."""
        pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
class BaseEnvironment:
    """Base class for environments.

    Every hook is a no-op here; subclasses override the ones they need.
    Construction sets ``to_render`` to ``False`` and then calls
    ``self.reset()``, so a subclass whose ``reset`` needs extra attributes
    must assign them before calling this constructor.
    """

    def __init__(self):
        self.to_render = False
        self.reset()

    def execute_action(self, action):
        """Apply ``action`` to the environment; a no-op in the base class."""
        pass

    def reset(self):
        """Return the environment to its initial state; a no-op here."""
        pass

    def render(self):
        """Draw the current state of the environment; a no-op here."""
        pass

    def close(self):
        """Release environment resources; a no-op in the base class.

        Provided so that code which always closes its environment (such as
        ``BasicTrainer.train``) also works with subclasses that have
        nothing to release.
        """
        pass

    def set_render(self, to_render):
        """Set the ``to_render`` flag."""
        self.to_render = to_render

    # Original camelCase name, kept so existing callers continue to work.
    setRender = set_render
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import gym | ||
from rlkit.core import BaseEnvironment | ||
|
||
|
||
class GymEnvironment(BaseEnvironment):
    """Expose a ``gym`` environment through the ``BaseEnvironment`` interface.

    The latest transition is cached on the instance as ``state``,
    ``reward``, ``done`` and ``info``.
    """

    def __init__(self, params):
        """Create the gym environment named by ``params["env_name"]``.

        Args:
            params: Mapping with at least the key ``"env_name"``.
        """
        self.params = params
        self.env_name = params["env_name"]
        self.env = gym.make(self.env_name)
        # The base constructor calls self.reset(), which needs self.env,
        # so it must run after the environment has been created.
        super(GymEnvironment, self).__init__()

    def execute_action(self, action):
        """Apply ``action`` and record the resulting transition.

        Goes through ``step`` so that ``state``, ``reward``, ``done`` and
        ``info`` stay in sync with the underlying environment.
        """
        self.step(action)

    def get_action_space(self):
        """Return the wrapped environment's ``action_space``."""
        return self.env.action_space

    def reset(self, reset_values=True):
        """Reset the wrapped environment.

        Args:
            reset_values: When true, clear the cached transition fields
                before the environment itself is reset.
        """
        if reset_values:
            self.reset_values()
        self.reset_env()

    def reset_values(self):
        """Clear the cached ``state``, ``reward``, ``done`` and ``info``."""
        self.state = None
        self.reward = None
        self.done = False
        self.info = None

    def reset_env(self):
        """Reset the gym environment and keep its first observation."""
        # Assumes the classic gym API, where reset() returns the initial
        # observation. Without storing it, ``state`` would stay None until
        # the first step.
        self.state = self.env.reset()

    def close(self):
        """Close the wrapped environment and return whatever it returns."""
        print("closing env")
        return self.env.close()

    def render(self):
        """Render the wrapped environment."""
        self.env.render()

    def step(self, action):
        """Advance one step, cache the transition and return it.

        Returns:
            The tuple ``(state, reward, done, info)`` from ``env.step``.
        """
        self.state, self.reward, self.done, self.info = self.env.step(action)
        return (self.state, self.reward, self.done, self.info, )
|
||
|
||
if __name__ == "__main__":
    # GymEnvironment reads params["env_name"], so it needs a mapping rather
    # than the bare environment name.
    test_env = GymEnvironment({"env_name": "MountainCarContinuous-v0"})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from rlkit.core import BaseEnvironment | ||
from vizdoom import * | ||
|
||
class VizDoomEnvironment(BaseEnvironment):
    """Environment backed by a ViZDoom ``DoomGame``.

    The game is not created by the constructor; call ``initialize_env``
    to build it.
    """

    def __init__(self, params):
        """Store the configuration.

        Args:
            params: Mapping with at least the key ``"env_name"``.
        """
        super(VizDoomEnvironment, self).__init__()
        self.params = params
        self.env_name = params["env_name"]

    def initialize_env(self, config_path="../config/basic.cfg"):
        """Create the ``DoomGame``, load its configuration and start it.

        Args:
            config_path: Path of the ViZDoom config file to load. The
                default is a relative path; what it resolves against is
                not visible here (presumably the working directory;
                TODO confirm).
        """
        self.env = DoomGame()
        self.env.load_config(config_path)
        self.env.init()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .basic_trainer import BasicTrainer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,37 @@ | ||
class BasicTrainer:
    """Empty trainer stub.

    NOTE(review): this looks like the removed side of the diff hunk; the
    ``BasicTrainer(BaseTrainer)`` class defined further down replaces it.
    """

    def __init__(self):
        pass
from rlkit.core import BaseTrainer | ||
|
||
|
||
class BasicTrainer(BaseTrainer):
    """Run an agent in an environment for a fixed number of episodes."""

    def __init__(self, params, agent, environment):
        """Store the collaborators and read the training settings.

        Args:
            params: Mapping with the keys ``"train_interval"``,
                ``"run_name"``, ``"episodes"`` and ``"steps"``.
            agent: Object providing ``get_action(state)`` and ``train()``.
            environment: Object providing ``reset()``, ``step(action)``,
                ``render()``, ``close()`` and the attributes ``state`` and
                ``done``.
        """
        self.agent = agent
        self.environment = environment
        super(BasicTrainer, self).__init__(params)

        self.train_interval = params["train_interval"]
        self.run_name = params["run_name"]
        self.episodes = params["episodes"]
        self.steps = params["steps"]
        # train() reads and increments global_step. BaseTrainer is not
        # visible from here, so keep whatever it set and fall back to 0.
        self.global_step = getattr(self, "global_step", 0)

    def do_step(self):
        """Ask the agent for an action, apply it and render the result."""
        action = self.agent.get_action(self.environment.state)
        self.environment.step(action)
        self.environment.render()  # TODO: find better solution

    def train(self):
        """Run the training loop; the environment is always closed at the end.

        Each episode resets the environment and steps until ``self.steps``
        steps have run or the environment reports ``done``. The agent is
        trained whenever the global step count is a positive multiple of
        ``train_interval``.
        """
        try:
            for episode in range(1, self.episodes + 1):
                step = 0
                self.environment.reset()
                while step < self.steps and not self.environment.done:
                    print("episode: {}, step: {}".format(episode, step))
                    self.do_step()

                    # Train agent
                    if self.global_step > 0 and not self.global_step % self.train_interval:
                        self.agent.train()

                    # Increment step counts
                    step += 1
                    self.global_step += 1
        finally:
            self.environment.close()