Skip to content

Commit

Permalink
Add initial metrics code
Browse files Browse the repository at this point in the history
  • Loading branch information
Shubham Jha committed Jun 20, 2020
1 parent 974d52d commit e6b2cb3
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 23 deletions.
17 changes: 11 additions & 6 deletions rlkit/__main__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from rlkit.agents import RandomAgent
from rlkit.core import Metrics
from rlkit.environments.gym_environment import GymEnvironment
from rlkit.trainers import BasicTrainer

params = {

"experiment_params": {
"experiment_name": "debug_expt-0"
}

"environment_params": {
"env_name": "SpaceInvaders-v0",
},
"agent_params": {

},
"agent_params": {},
"training_params": {
"run_name": "test_run",
"train_interval": 10,
Expand All @@ -17,7 +21,8 @@
},
}

env = GymEnvironment(params["environment_params"])
agent = RandomAgent(params["agent_params"], env.get_action_space())
trainer = BasicTrainer(params["training_params"], agent, env)
metrics = Metrics()
env = GymEnvironment(params["environment_params"], metrics)
agent = RandomAgent(params["agent_params"], env.get_action_space(), metrics)
trainer = BasicTrainer(params["training_params"], agent, env, metrics)
trainer.train()
3 changes: 2 additions & 1 deletion rlkit/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base_action_space import BaseActionSpace
from .base_agent import BaseAgent
from .base_environment import BaseEnvironment
from .base_trainer import BaseTrainer
from .base_trainer import BaseTrainer
from .metrics import Metrics
11 changes: 8 additions & 3 deletions rlkit/core/base_environment.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,19 @@
class BaseEnvironment:
def __init__(self):
def __init__(self, params, metrics):
if not hasattr(self, "params"):
self.params = params
if not hasattr(self, "metrics"):
self.metrics = metrics
self.to_render = False
self.done = False
self.reset()
self.global_step = 0

def close(self):
pass

def execute_action(self, action):
pass
# def execute_action(self, action):
# pass

def reset(self):
pass
Expand Down
12 changes: 9 additions & 3 deletions rlkit/core/base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from collections import defaultdict


class BaseTrainer:
    """Base class holding the state shared by all trainers.

    It stores the collaborating objects (agent, environment, metrics) and
    the loop limits read from ``params``. ``do_step`` and ``train`` are
    no-op hooks here; concrete trainers such as ``BasicTrainer`` override
    them.
    """

    def __init__(self, params, agent, environment, metrics):
        """Store collaborators and read loop limits from ``params``.

        Args:
            params: Mapping of training parameters. Optional keys read
                here: ``"episodes"`` (default 10) and ``"steps"``
                (default 100).
            agent: Object stored as ``self.agent``.
            environment: Object stored as ``self.environment``.
            metrics: Object stored as ``self.metrics``.
        """
        self.metrics = metrics
        self.agent = agent
        self.environment = environment

        # Counter of steps across all episodes; starts at zero.
        self.global_step = 0
        self.episodes = params.get("episodes", 10)
        self.steps = params.get("steps", 100)

    def do_step(self):
        """Perform a single step; no-op in the base class."""

    def train(self):
        """Run the training loop; no-op in the base class."""

21 changes: 21 additions & 0 deletions rlkit/core/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import logging
from collections import defaultdict


class Metrics:
    """In-memory metric store that also logs each recorded value.

    Values are kept in ``metrics_dict`` as
    ``metrics_dict[namespace][log_step][metric_name] = metric_val``.
    Every recorded value is also written to a ``logging`` logger and, when
    a ``comet_experiment`` object is supplied, forwarded to its
    ``log_metric`` method.
    """

    def __init__(self, logger_name="rlkit.metrics", log_level=logging.INFO,
                 comet_experiment=None):
        """Create an empty metric store and configure its logger.

        Args:
            logger_name: Name passed to ``logging.getLogger``. Has a
                default so that ``Metrics()`` works without arguments.
            log_level: Level applied to the logger via ``setLevel``.
            comet_experiment: Optional object exposing
                ``log_metric(name, value, step=...)``; ``None`` disables
                forwarding.
        """
        # defaultdict needs a *callable* factory: wrap the inner defaultdict
        # in a lambda so missing namespaces/steps are created on demand.
        self.metrics_dict = defaultdict(lambda: defaultdict(dict))
        self.comet_experiment = comet_experiment
        self.logger = logging.getLogger(logger_name)
        self.logger.setLevel(log_level)

    def log_metric(self, metric_name, metric_val, log_step, namespace="default"):
        """Record one metric value for a given step.

        Args:
            metric_name: Key under which the value is stored.
            metric_val: Value to store.
            log_step: Step index the value belongs to.
            namespace: Top-level grouping key in ``metrics_dict``.
        """
        self.metrics_dict[namespace][log_step][metric_name] = metric_val
        # Lazy %-style arguments: the message is only formatted if emitted.
        self.logger.info("STEP %s:%s:%s", log_step, metric_name, metric_val)
        # TODO: add timestamp

        if self.comet_experiment is not None:
            self.comet_experiment.log_metric(metric_name, metric_val, step=log_step)
10 changes: 6 additions & 4 deletions rlkit/environments/gym_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@


class GymEnvironment(BaseEnvironment):
def __init__(self, params):
def __init__(self, params, metrics):
self.params = params
self.metrics = metrics
self.env_name = params["env_name"]
self.env = gym.make(self.env_name)
super(GymEnvironment, self).__init__()
super(GymEnvironment, self).__init__(params, metrics)

def execute_action(self, action):
self.env.step(action)
# def execute_action(self, action): # TODO: remove from base class
# self.env.step(action)

def get_action_space(self):
return self.env.action_space
Expand Down Expand Up @@ -38,6 +39,7 @@ def render(self):

def step(self, action):
self.state, self.reward, self.done, self.info = self.env.step(action)
self.global_step += 1
return (self.state, self.reward, self.done, self.info, )


Expand Down
14 changes: 8 additions & 6 deletions rlkit/trainers/basic_trainer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from rlkit.core import BaseTrainer
from rlkit.core.metrics import Metrics # TODO: fix import path


class BasicTrainer(BaseTrainer):
def __init__(self, params, agent, environment):
self.agent = agent
self.environment = environment
super(BasicTrainer, self).__init__(params)
def __init__(self, params, agent, environment, metrics):
super(BasicTrainer, self).__init__(params, agent, environment, metrics)

self.train_interval = params["train_interval"]
self.run_name = params["run_name"]
Expand All @@ -14,7 +13,7 @@ def __init__(self, params, agent, environment):

def do_step(self):
action = self.agent.get_action(self.environment.state)
self.environment.step(action)
return self.environment.step(action)
# self.environment.render() # TODO: find better solution

def train(self):
Expand All @@ -24,7 +23,7 @@ def train(self):
self.environment.reset()
while step < self.steps and not self.environment.done:
print("episode: {}, step: {}".format(episode, step))
self.do_step()
state, reward, done, info = self.do_step()

# Train agent
if self.global_step > 0 and not self.global_step % self.train_interval:
Expand All @@ -33,5 +32,8 @@ def train(self):
# Increment step counts
step += 1
self.global_step += 1

self.metrics.log_metric("timestep", step, log_step=self.global_step, namespace="trainer")
self.metrics.log_metric("episode", episode, log_step=self.global_step, namespace="trainer")
finally:
self.environment.close()

0 comments on commit e6b2cb3

Please sign in to comment.