-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Shubham Jha
committed
Jun 11, 2020
1 parent
7026378
commit 974d52d
Showing
9 changed files
with
702 additions
and
8 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Lines starting with # are treated as comments (or with whitespaces+#). | ||
# It doesn't matter if you use capital letters or not. | ||
# It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. | ||
|
||
doom_map = map01 | ||
|
||
# Rewards | ||
living_reward = -1 | ||
|
||
# Rendering options | ||
screen_resolution = RES_320X240 | ||
screen_format = CRCGCB | ||
render_hud = True | ||
render_crosshair = false | ||
render_weapon = true | ||
render_decals = false | ||
render_particles = false | ||
window_visible = true | ||
|
||
# make episodes start after 20 tics (after unholstering the gun) | ||
episode_start_time = 14 | ||
|
||
# make episodes finish after 300 actions (tics) | ||
episode_timeout = 300 | ||
|
||
# Available buttons | ||
available_buttons = | ||
{ | ||
MOVE_LEFT | ||
MOVE_RIGHT | ||
ATTACK | ||
} | ||
|
||
# Game variables that will be in the state | ||
available_game_variables = { AMMO2} | ||
|
||
mode = PLAYER | ||
doom_skill = 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from rlkit.agents import RandomAgent | ||
from rlkit.environments.gym_environment import GymEnvironment | ||
from rlkit.environments.vizdoom_environment import VizDoomEnvironment | ||
from rlkit.trainers import BasicTrainer | ||
|
||
params = { | ||
"environment_params": { | ||
# "env_name": "SpaceInvaders-v0", | ||
}, | ||
"agent_params": { | ||
|
||
}, | ||
"training_params": { | ||
"run_name": "test_run", | ||
"train_interval": 10, | ||
"episodes": 5, | ||
"steps": 500, | ||
}, | ||
} | ||
|
||
# env = GymEnvironment(params["environment_params"]) | ||
env = VizDoomEnvironment(params["environment_params"]) | ||
agent = RandomAgent(params["agent_params"], env.get_action_space()) | ||
trainer = BasicTrainer(params["training_params"], agent, env) | ||
trainer.train() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .base_action_space import BaseActionSpace | ||
from .base_agent import BaseAgent | ||
from .base_environment import BaseEnvironment | ||
from .base_trainer import BaseTrainer |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
class BaseActionSpace: | ||
def __init__(self): | ||
pass | ||
|
||
def sample(self): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,64 @@ | ||
from rlkit.core import BaseEnvironment | ||
import random | ||
import time | ||
|
||
from rlkit.core import BaseEnvironment, BaseActionSpace | ||
from vizdoom import * | ||
|
||
class VizDoomEnvironment(BaseEnvironment): | ||
|
||
class VizDoomActionSpace(BaseActionSpace): | ||
def __init__(self): | ||
self.actions = [ | ||
# http://www.cs.put.poznan.pl/visualdoomai/tutorial.html | ||
[0, 0, 1], # shoot | ||
[1, 0, 0], # left | ||
[0, 1, 0], # right | ||
] | ||
super(VizDoomEnvironment.VizDoomActionSpace, self).__init__() | ||
|
||
def sample(self): | ||
return random.sample(self.actions, 1)[0] | ||
|
||
def __init__(self, params): | ||
self.action_space = self.VizDoomActionSpace() | ||
self.initialize_env() | ||
super(VizDoomEnvironment, self).__init__() | ||
self.env_name = params["env_name"] | ||
|
||
pass | ||
|
||
def initialize_env(self): | ||
self.env = DoomGame() | ||
self.env.load_config("../config/basic.cfg") | ||
self.env.init() | ||
self.env.load_config("./basic.cfg") # TODO: load via params | ||
self.env.init() | ||
|
||
def get_action_space(self): | ||
return self.action_space | ||
|
||
def reset(self, reset_values=True): | ||
if reset_values: | ||
self.reset_values() | ||
self.reset_env() | ||
|
||
def reset_values(self): | ||
self.state = None | ||
self.reward = None | ||
self.done = False | ||
self.info = None | ||
|
||
def reset_env(self): | ||
self.env.new_episode() | ||
|
||
def step(self, action): | ||
self.reward = self.env.make_action(action) | ||
|
||
# TODO: see if need to get image buffer | ||
# TODO: see if this happens before/after reward | ||
self.state = self.env.get_state() | ||
|
||
self.done = self.env.is_episode_finished() | ||
if not self.done: | ||
self.info = self.state.game_variables | ||
else: | ||
self.info = None | ||
|
||
print(action, self.done, self.env.get_total_reward(), self.info) | ||
time.sleep(0.02) # TODO: remove | ||
return (self.state, self.reward, self.done, self.info, ) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters