diff --git a/rlkit/__main__.py b/rlkit/__main__.py
index 272df3f..1b6a1fd 100644
--- a/rlkit/__main__.py
+++ b/rlkit/__main__.py
@@ -1,14 +1,16 @@
 from rlkit.agents import RandomAgent
+from rlkit.core import Metrics
 from rlkit.environments.gym_environment import GymEnvironment
 from rlkit.trainers import BasicTrainer
 
 params = {
+    "experiment_params": {
+        "experiment_name": "debug_expt-0"
+    },
     "environment_params": {
         "env_name": "SpaceInvaders-v0",
     },
-    "agent_params": {
-
-    },
+    "agent_params": {},
     "training_params": {
         "run_name": "test_run",
         "train_interval": 10,
@@ -17,7 +19,8 @@
     },
 }
 
-env = GymEnvironment(params["environment_params"])
-agent = RandomAgent(params["agent_params"], env.get_action_space())
-trainer = BasicTrainer(params["training_params"], agent, env)
+metrics = Metrics()
+env = GymEnvironment(params["environment_params"], metrics)
+agent = RandomAgent(params["agent_params"], env.get_action_space(), metrics)
+trainer = BasicTrainer(params["training_params"], agent, env, metrics)
 trainer.train()
diff --git a/rlkit/core/__init__.py b/rlkit/core/__init__.py
index dbe1e9b..776f65b 100644
--- a/rlkit/core/__init__.py
+++ b/rlkit/core/__init__.py
@@ -1,4 +1,5 @@
 from .base_action_space import BaseActionSpace
 from .base_agent import BaseAgent
 from .base_environment import BaseEnvironment
-from .base_trainer import BaseTrainer
\ No newline at end of file
+from .base_trainer import BaseTrainer
+from .metrics import Metrics
diff --git a/rlkit/core/base_environment.py b/rlkit/core/base_environment.py
index 313e10e..cbe93a6 100644
--- a/rlkit/core/base_environment.py
+++ b/rlkit/core/base_environment.py
@@ -1,14 +1,17 @@
 class BaseEnvironment:
-    def __init__(self):
+    def __init__(self, params, metrics):
+        # Subclasses may set params/metrics before calling super().__init__().
+        if not hasattr(self, "params"):
+            self.params = params
+        if not hasattr(self, "metrics"):
+            self.metrics = metrics
+        self.global_step = 0
         self.to_render = False
         self.done = False
         self.reset()
 
     def close(self):
         pass
 
-    def execute_action(self, action):
-        pass
-
     def reset(self):
         pass
diff --git a/rlkit/core/base_trainer.py b/rlkit/core/base_trainer.py
index f88eaca..724f50a 100644
--- a/rlkit/core/base_trainer.py
+++ b/rlkit/core/base_trainer.py
@@ -1,11 +1,15 @@
 class BaseTrainer:
-    def __init__(self, params):
-        self.episodes = params.get("episodes", 10);
+    def __init__(self, params, agent, environment, metrics):
+        self.agent = agent
+        self.environment = environment
+        self.metrics = metrics
+        self.global_step = 0
+        self.episodes = params.get("episodes", 10)
         self.steps = params.get("steps", 100)
 
     def do_step(self):
         pass
 
     def train(self):
         pass
 
diff --git a/rlkit/core/metrics.py b/rlkit/core/metrics.py
new file mode 100644
index 0000000..3d58599
--- /dev/null
+++ b/rlkit/core/metrics.py
@@ -0,0 +1,19 @@
+import logging
+from collections import defaultdict
+
+
+class Metrics:
+    def __init__(self, logger_name="rlkit", log_level=logging.INFO, comet_experiment=None):
+        # metrics_dict[namespace][log_step][metric_name] = metric_val
+        self.metrics_dict = defaultdict(lambda: defaultdict(dict))
+        self.comet_experiment = comet_experiment
+        self.logger = logging.getLogger(logger_name)
+        self.logger.setLevel(log_level)
+
+    def log_metric(self, metric_name, metric_val, log_step, namespace="default"):
+        # TODO: add timestamp
+        self.metrics_dict[namespace][log_step][metric_name] = metric_val
+        self.logger.info("STEP {}:{}:{}".format(log_step, metric_name, metric_val))
+
+        if self.comet_experiment is not None:
+            self.comet_experiment.log_metric(metric_name, metric_val, step=log_step)
diff --git a/rlkit/environments/gym_environment.py b/rlkit/environments/gym_environment.py
index 1b44c21..992f5d9 100644
--- a/rlkit/environments/gym_environment.py
+++ b/rlkit/environments/gym_environment.py
@@ -3,14 +3,12 @@
 
 class GymEnvironment(BaseEnvironment):
-    def __init__(self, params):
+    def __init__(self, params, metrics):
         self.params = params
+        self.metrics = metrics
         self.env_name = params["env_name"]
         self.env = gym.make(self.env_name)
-        super(GymEnvironment, self).__init__()
+        super(GymEnvironment, self).__init__(params, metrics)
 
-    def execute_action(self, action):
-        self.env.step(action)
-
     def get_action_space(self):
         return self.env.action_space
 
@@ -38,5 +36,6 @@ def render(self):
 
     def step(self, action):
         self.state, self.reward, self.done, self.info = self.env.step(action)
+        self.global_step += 1
         return (self.state, self.reward, self.done, self.info, )
diff --git a/rlkit/trainers/basic_trainer.py b/rlkit/trainers/basic_trainer.py
index 02bdac0..821b8b7 100644
--- a/rlkit/trainers/basic_trainer.py
+++ b/rlkit/trainers/basic_trainer.py
@@ -1,11 +1,9 @@
 from rlkit.core import BaseTrainer
 
 
 class BasicTrainer(BaseTrainer):
-    def __init__(self, params, agent, environment):
-        self.agent = agent
-        self.environment = environment
-        super(BasicTrainer, self).__init__(params)
+    def __init__(self, params, agent, environment, metrics):
+        super(BasicTrainer, self).__init__(params, agent, environment, metrics)
 
         self.train_interval = params["train_interval"]
         self.run_name = params["run_name"]
@@ -14,7 +12,7 @@ def __init__(self, params, agent, environment):
 
     def do_step(self):
         action = self.agent.get_action(self.environment.state)
-        self.environment.step(action)
+        return self.environment.step(action)
         # self.environment.render()  # TODO: find better solution
 
     def train(self):
@@ -24,7 +22,7 @@
             self.environment.reset()
             while step < self.steps and not self.environment.done:
                 print("episode: {}, step: {}".format(episode, step))
-                self.do_step()
+                state, reward, done, info = self.do_step()
 
                 # Train agent
                 if self.global_step > 0 and not self.global_step % self.train_interval:
@@ -33,5 +31,8 @@
                 # Increment step counts
                 step += 1
                 self.global_step += 1
+
+                self.metrics.log_metric("timestep", step, log_step=self.global_step, namespace="trainer")
+                self.metrics.log_metric("episode", episode, log_step=self.global_step, namespace="trainer")
         finally:
             self.environment.close()
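
A quick usage sketch of the new Metrics class on its own, outside the diff. It assumes the constructor defaults introduced above (logger_name="rlkit", log_level=logging.INFO), which are what let the no-arg Metrics() call in __main__.py work, and assumes no Comet experiment is attached; logging.basicConfig is only there so the demo has a handler to print through.

    import logging

    from rlkit.core import Metrics

    logging.basicConfig(level=logging.INFO)  # give the root logger a handler for the demo

    metrics = Metrics()  # no-arg construction relies on the defaults above

    # Each call stores metrics_dict[namespace][log_step][metric_name] = metric_val
    # and logs a line of the form "STEP <log_step>:<metric_name>:<metric_val>".
    metrics.log_metric("timestep", 7, log_step=42, namespace="trainer")
    metrics.log_metric("episode", 3, log_step=42, namespace="trainer")

    assert metrics.metrics_dict["trainer"][42]["timestep"] == 7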