From 24cd6200b1a037f5a0806f9001c4cd48069a056d Mon Sep 17 00:00:00 2001 From: Xingdong Zuo Date: Fri, 2 Nov 2018 10:11:49 +0100 Subject: [PATCH] Update Former-commit-id: b73f9f094298c17cf8dbf4d921d666ed3df6b451 [formerly d0e5b23b985d3ede488cc0866ab16796bea317ba] Former-commit-id: 1e1166290a20adfe31d678533f7022141ecc7af0 --- docs/source/history.rst | 17 ++ examples/policy_gradient/a2c/algo.py | 113 ++------ examples/policy_gradient/a2c/engine.py | 74 +++--- examples/policy_gradient/a2c/experiment.py | 27 +- examples/policy_gradient/a2c/main.ipynb | 251 ++++++++++++++++++ examples/policy_gradient/a2c/main.py | 3 +- examples/policy_gradient/a2c/model.py | 231 ++++++++++++++++ examples/policy_gradient/a2c/policy.py | 61 ----- examples/policy_gradient/a2c_agent.py | 152 ----------- examples/policy_gradient/reinforce/algo.py | 1 - .../policy_gradient/reinforce/experiment.py | 8 +- examples/policy_gradient/vpg/algo.py | 97 ++----- examples/policy_gradient/vpg/engine.py | 76 +++--- examples/policy_gradient/vpg/experiment.py | 15 +- examples/policy_gradient/vpg/main.ipynb | 227 +++------------- examples/policy_gradient/vpg/main.py | 3 +- examples/policy_gradient/vpg/model.py | 201 ++++++++++---- lagom/history/metrics/__init__.py | 8 + lagom/history/metrics/bootstrapped_returns.py | 79 ++++++ lagom/history/metrics/final_states.py | 37 +++ lagom/history/metrics/terminal_states.py | 47 ++++ lagom/runner/segment_runner.py | 1 + lagom/runner/trajectory_runner.py | 1 + test/test_history.py | 144 ++++++++++ 24 files changed, 1116 insertions(+), 758 deletions(-) create mode 100644 examples/policy_gradient/a2c/model.py delete mode 100644 examples/policy_gradient/a2c/policy.py delete mode 100644 examples/policy_gradient/a2c_agent.py create mode 100644 lagom/history/metrics/__init__.py create mode 100644 lagom/history/metrics/bootstrapped_returns.py create mode 100644 lagom/history/metrics/final_states.py create mode 100644 lagom/history/metrics/terminal_states.py diff --git a/docs/source/history.rst b/docs/source/history.rst index 0fa2aa4d..c2b1462d 100644 --- a/docs/source/history.rst +++ b/docs/source/history.rst @@ -15,3 +15,20 @@ lagom.history: History .. autoclass:: Segment :members: + +Metrics +---------------- + +.. currentmodule:: lagom.history.metrics + +.. autofunction:: terminal_state_from_trajectory + +.. autofunction:: terminal_state_from_segment + +.. autofunction:: final_state_from_trajectory + +.. autofunction:: final_state_from_segment + +.. autofunction:: bootstrapped_returns_from_trajectory + +.. 
autofunction:: bootstrapped_returns_from_segment diff --git a/examples/policy_gradient/a2c/algo.py b/examples/policy_gradient/a2c/algo.py index 2ed33a3a..b48aea03 100644 --- a/examples/policy_gradient/a2c/algo.py +++ b/examples/policy_gradient/a2c/algo.py @@ -3,55 +3,42 @@ from itertools import count import numpy as np - import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from lagom import set_global_seeds +from lagom import Logger +from lagom.utils import pickle_dump +from lagom.utils import set_global_seeds + from lagom import BaseAlgorithm -from lagom import pickle_dump from lagom.envs import make_gym_env from lagom.envs import make_vec_env from lagom.envs import EnvSpec from lagom.envs.vec_env import SerialVecEnv -from lagom.envs.vec_env import ParallelVecEnv from lagom.envs.vec_env import VecStandardize -from lagom.core.policies import CategoricalPolicy -from lagom.core.policies import GaussianPolicy - from lagom.runner import TrajectoryRunner from lagom.runner import SegmentRunner -from lagom.agents import A2CAgent - +from model import Agent from engine import Engine -from policy import Network -from policy import LSTM class Algorithm(BaseAlgorithm): - def __call__(self, config, seed, device_str): + def __call__(self, config, seed, device): set_global_seeds(seed) - device = torch.device(device_str) logdir = Path(config['log.dir']) / str(config['ID']) / str(seed) - - # Environment related + env = make_vec_env(vec_env_class=SerialVecEnv, make_env=make_gym_env, env_id=config['env.id'], num_env=config['train.N'], # batched environment - init_seed=seed, - rolling=True) + init_seed=seed) eval_env = make_vec_env(vec_env_class=SerialVecEnv, make_env=make_gym_env, env_id=config['env.id'], num_env=config['eval.N'], - init_seed=seed, - rolling=False) + init_seed=seed) if config['env.standardize']: # running averages of observation and reward env = VecStandardize(venv=env, use_obs=True, @@ -60,102 +47,37 @@ def __call__(self, config, seed, device_str): clip_reward=10., gamma=0.99, eps=1e-8) - eval_env = VecStandardize(venv=eval_env, # remember to synchronize running averages during evaluation !!! 
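# A minimal sketch (plain numpy, not lagom's VecStandardize, and not part of this diff) of why
# eval_env is re-wrapped with the training env's running statistics: evaluation should
# standardize observations with the same frozen mean/std the policy was trained under, while
# rewards are left untouched because no learning happens. The example arrays below are assumptions.
import numpy as np

def standardize_obs(obs, mean, std, clip_obs=10.0, eps=1e-8):
    # same transformation the training wrapper applies, but with constant statistics
    return np.clip((obs - mean) / (std + eps), -clip_obs, clip_obs)

train_mu = np.array([0.1, -0.3])    # stands in for env.obs_runningavg.mu
train_sigma = np.array([1.2, 0.8])  # stands in for env.obs_runningavg.sigma
print(standardize_obs(np.array([0.5, -1.0]), train_mu, train_sigma))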
+ eval_env = VecStandardize(venv=eval_env, use_obs=True, use_reward=False, # do not process rewards, no training clip_obs=env.clip_obs, clip_reward=env.clip_reward, gamma=env.gamma, eps=env.eps, - constant_obs_mean=env.obs_runningavg.mu, # use current running average as constant + constant_obs_mean=env.obs_runningavg.mu, constant_obs_std=env.obs_runningavg.sigma) env_spec = EnvSpec(env) + + agent = Agent(config, env_spec, device) - # Network and policy - if config['network.recurrent']: - network = LSTM(config=config, device=device, env_spec=env_spec) - else: - network = Network(config=config, device=device, env_spec=env_spec) - if env_spec.control_type == 'Discrete': - policy = CategoricalPolicy(config=config, - network=network, - env_spec=env_spec, - device=device, - learn_V=True) - elif env_spec.control_type == 'Continuous': - policy = GaussianPolicy(config=config, - network=network, - env_spec=env_spec, - device=device, - learn_V=True, - min_std=config['agent.min_std'], - std_style=config['agent.std_style'], - constant_std=config['agent.constant_std'], - std_state_dependent=config['agent.std_state_dependent'], - init_std=config['agent.init_std']) - - # Optimizer and learning rate scheduler - optimizer = optim.Adam(policy.network.parameters(), lr=config['algo.lr']) - if config['algo.use_lr_scheduler']: - if 'train.iter' in config: # iteration-based - max_epoch = config['train.iter'] - elif 'train.timestep' in config: # timestep-based - max_epoch = config['train.timestep'] + 1 # avoid zero lr in final iteration - lambda_f = lambda epoch: 1 - epoch/max_epoch - lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_f) - - # Agent - kwargs = {'device': device} - if config['algo.use_lr_scheduler']: - kwargs['lr_scheduler'] = lr_scheduler - agent = A2CAgent(config=config, - policy=policy, - optimizer=optimizer, - **kwargs) - - # Runner - runner = SegmentRunner(agent=agent, - env=env, - gamma=config['algo.gamma']) - eval_runner = TrajectoryRunner(agent=agent, - env=eval_env, - gamma=1.0) + runner = SegmentRunner(config, agent, env) + eval_runner = TrajectoryRunner(config, agent, eval_env) - # Engine - engine = Engine(agent=agent, - runner=runner, - config=config, - eval_runner=eval_runner) + engine = Engine(agent, runner, config, eval_runner=eval_runner) - # Training and evaluation train_logs = [] eval_logs = [] - - if config['network.recurrent']: - rnn_states_buffer = agent.policy.rnn_states # for SegmentRunner - for i in count(): if 'train.iter' in config and i >= config['train.iter']: # enough iterations break elif 'train.timestep' in config and agent.total_T >= config['train.timestep']: # enough timesteps break - if config['network.recurrent']: - if isinstance(rnn_states_buffer, list): # LSTM: [h, c] - rnn_states_buffer = [buf.detach() for buf in rnn_states_buffer] - else: - rnn_states_buffer = rnn_states_buffer.detach() - agent.policy.rnn_states = rnn_states_buffer - - train_output = engine.train(n=i) + train_output = engine.train(i) - # Logging if i == 0 or (i+1) % config['log.record_interval'] == 0 or (i+1) % config['log.print_interval'] == 0: train_log = engine.log_train(train_output) - if config['network.recurrent']: - rnn_states_buffer = agent.policy.rnn_states # for SegmentRunner - with torch.no_grad(): # disable grad, save memory eval_output = engine.eval(n=i) eval_log = engine.log_eval(eval_output) @@ -164,7 +86,6 @@ def __call__(self, config, seed, device_str): train_logs.append(train_log) eval_logs.append(eval_log) - # Save all loggings pickle_dump(obj=train_logs, 
f=logdir/'train_logs', ext='.pkl') pickle_dump(obj=eval_logs, f=logdir/'eval_logs', ext='.pkl') diff --git a/examples/policy_gradient/a2c/engine.py b/examples/policy_gradient/a2c/engine.py index 18bfbd39..ff33b817 100644 --- a/examples/policy_gradient/a2c/engine.py +++ b/examples/policy_gradient/a2c/engine.py @@ -3,7 +3,7 @@ import torch from lagom import Logger -from lagom import color_str +from lagom.utils import color_str from lagom.engine import BaseEngine @@ -15,12 +15,10 @@ class Engine(BaseEngine): def train(self, n): - self.agent.policy.network.train() # train mode + self.agent.train() - # Collect a list of Segment D = self.runner(T=self.config['train.T']) - # Train agent with collected data out_agent = self.agent.learn(D) train_output = {} @@ -31,35 +29,31 @@ def train(self, n): return train_output def log_train(self, train_output, **kwargs): - # Unpack D = train_output['D'] out_agent = train_output['out_agent'] n = train_output['n'] - # Loggings - logger = Logger(name='train_logger') - logger.log('train_iteration', n+1) # starts from 1 - if self.config['algo.use_lr_scheduler']: - logger.log('current_lr', out_agent['current_lr']) - - logger.log('loss', out_agent['loss']) - logger.log('policy_loss', out_agent['policy_loss']) - logger.log('policy_entropy', -out_agent['entropy_loss']) # entropy: negative entropy loss - logger.log('value_loss', out_agent['value_loss']) + logger = Logger() + logger('train_iteration', n+1) # starts from 1 + if 'current_lr' in out_agent: + logger('current_lr', out_agent['current_lr']) + logger('loss', out_agent['loss']) + logger('policy_loss', out_agent['policy_loss']) + logger('policy_entropy', -out_agent['entropy_loss']) + logger('value_loss', out_agent['value_loss']) all_immediate_reward = [segment.all_r for segment in D] num_timesteps = sum([segment.T for segment in D]) - logger.log('num_segments', len(D)) - logger.log('num_subsegments', sum([len(segment.trajectories) for segment in D])) - logger.log('num_timesteps', num_timesteps) - logger.log('accumulated_trained_timesteps', self.agent.total_T) - logger.log('average_immediate_reward', np.mean(all_immediate_reward)) - logger.log('std_immediate_reward', np.std(all_immediate_reward)) - logger.log('min_immediate_reward', np.min(all_immediate_reward)) - logger.log('max_immediate_reward', np.max(all_immediate_reward)) + logger('num_segments', len(D)) + logger('num_subsegments', sum([len(segment.trajectories) for segment in D])) + logger('num_timesteps', num_timesteps) + logger('accumulated_trained_timesteps', self.agent.total_T) + logger('average_immediate_reward', np.mean(all_immediate_reward)) + logger('std_immediate_reward', np.std(all_immediate_reward)) + logger('min_immediate_reward', np.min(all_immediate_reward)) + logger('max_immediate_reward', np.max(all_immediate_reward)) - # Dump loggings if n == 0 or (n+1) % self.config['log.print_interval'] == 0: print('-'*50) logger.dump(keys=None, index=None, indent=0) @@ -68,16 +62,15 @@ def log_train(self, train_output, **kwargs): return logger.logs def eval(self, n): - self.agent.policy.network.eval() # evaluation mode + self.agent.eval() # Synchronize running average of observations for evaluation if self.config['env.standardize']: self.eval_runner.env.constant_obs_mean = self.runner.env.obs_runningavg.mu self.eval_runner.env.constant_obs_std = self.runner.env.obs_runningavg.sigma - # Collect a list of Trajectory T = self.eval_runner.env.T - D = self.eval_runner(T=T) + D = self.eval_runner(T) eval_output = {} eval_output['D'] = D @@ -87,29 +80,26 @@ def 
eval(self, n): return eval_output def log_eval(self, eval_output, **kwargs): - # Unpack D = eval_output['D'] n = eval_output['n'] T = eval_output['T'] - # Loggings - logger = Logger(name='eval_logger') + logger = Logger() batch_returns = [sum(trajectory.all_r) for trajectory in D] batch_T = [trajectory.T for trajectory in D] - logger.log('evaluation_iteration', n+1) - logger.log('num_trajectories', len(D)) - logger.log('max_allowed_horizon', T) - logger.log('average_horizon', np.mean(batch_T)) - logger.log('num_timesteps', np.sum(batch_T)) - logger.log('accumulated_trained_timesteps', self.agent.total_T) - logger.log('average_return', np.mean(batch_returns)) - logger.log('std_return', np.std(batch_returns)) - logger.log('min_return', np.min(batch_returns)) - logger.log('max_return', np.max(batch_returns)) - - # Dump loggings + logger('evaluation_iteration', n+1) + logger('num_trajectories', len(D)) + logger('max_allowed_horizon', T) + logger('average_horizon', np.mean(batch_T)) + logger('num_timesteps', np.sum(batch_T)) + logger('accumulated_trained_timesteps', self.agent.total_T) + logger('average_return', np.mean(batch_returns)) + logger('std_return', np.std(batch_returns)) + logger('min_return', np.min(batch_returns)) + logger('max_return', np.max(batch_returns)) + if n == 0 or (n+1) % self.config['log.print_interval'] == 0: print(color_str('+'*50, 'yellow', 'bold')) logger.dump(keys=None, index=None, indent=0) diff --git a/examples/policy_gradient/a2c/experiment.py b/examples/policy_gradient/a2c/experiment.py index a2cba83f..99219efa 100644 --- a/examples/policy_gradient/a2c/experiment.py +++ b/examples/policy_gradient/a2c/experiment.py @@ -6,8 +6,11 @@ class ExperimentWorker(BaseExperimentWorker): + def prepare(self): + pass + def make_algo(self): - algo = Algorithm(name='A2C') + algo = Algorithm() return algo @@ -16,16 +19,16 @@ class ExperimentMaster(BaseExperimentMaster): def make_configs(self): configurator = Configurator('grid') - configurator.fixed('cuda', False) # whether to use GPU + configurator.fixed('cuda', True) # whether to use GPU configurator.fixed('env.id', 'HalfCheetah-v2') configurator.fixed('env.standardize', True) # whether to use VecStandardize - configurator.fixed('network.recurrent', True) - configurator.fixed('network.hidden_sizes', [32]) # TODO: [64, 64] + configurator.fixed('network.recurrent', False) + configurator.fixed('network.hidden_sizes', [64, 64]) # TODO: [64, 64] configurator.fixed('algo.lr', 1e-3) - configurator.fixed('algo.use_lr_scheduler', False) + configurator.fixed('algo.use_lr_scheduler', True) configurator.fixed('algo.gamma', 0.99) configurator.fixed('agent.standardize_Q', False) # whether to standardize discounted returns @@ -38,15 +41,15 @@ def make_configs(self): configurator.fixed('agent.std_style', 'exp') # std parameterization, 'exp' or 'softplus' configurator.fixed('agent.constant_std', None) # constant std, set None to learn it configurator.fixed('agent.std_state_dependent', False) # whether to learn std with state dependency - configurator.fixed('agent.init_std', 0.5) # initial std for state-independent std + configurator.fixed('agent.init_std', 1.0) # initial std for state-independent std - configurator.fixed('train.timestep', 1e7) # either 'train.iter' or 'train.timestep' - configurator.fixed('train.N', 10) # number of segments per training iteration + configurator.fixed('train.timestep', 1e6) # either 'train.iter' or 'train.timestep' + configurator.fixed('train.N', 1) # number of segments per training iteration 
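# A quick bookkeeping sketch (not lagom code) for these A2C segment settings
# (train.timestep = 1e6, train.N = 1, train.T = 5): each update consumes
# train.N * train.T environment steps, so the run performs roughly 200,000
# updates before the training loop stops.
total_timesteps, N, T = int(1e6), 1, 5
print(total_timesteps // (N * T))  # 200000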
configurator.fixed('train.T', 5) # fixed-length segment rolling - configurator.fixed('eval.N', 100) # number of episodes to evaluate, do not specify T for complete episode + configurator.fixed('eval.N', 10) # number of episodes to evaluate, do not specify T for complete episode configurator.fixed('log.record_interval', 100) # interval to record the logging - configurator.fixed('log.print_interval', 500) # interval to print the logging to screen + configurator.fixed('log.print_interval', 100) # interval to print the logging to screen configurator.fixed('log.dir', 'logs') # logging directory list_config = configurator.make_configs() @@ -58,5 +61,5 @@ def make_seeds(self): return list_seed - def process_algo_result(self, config, seed, result): - assert result is None + def process_results(self, results): + assert all([result is None for result in results]) diff --git a/examples/policy_gradient/a2c/main.ipynb b/examples/policy_gradient/a2c/main.ipynb index c1442419..7732322a 100644 --- a/examples/policy_gradient/a2c/main.ipynb +++ b/examples/policy_gradient/a2c/main.ipynb @@ -1,5 +1,256 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "from model import Agent\n", + "from lagom.envs import make_gym_env, make_vec_env\n", + "from lagom.envs import EnvSpec\n", + "from lagom.envs.vec_env import SerialVecEnv\n", + "\n", + "from model import Policy\n", + "\n", + "from lagom.runner import SegmentRunner\n", + "\n", + "env = make_vec_env(SerialVecEnv, make_gym_env, 'CartPole-v1', 3, 0)\n", + "env_spec = EnvSpec(env)\n", + "\n", + "config = {'network.hidden_sizes': [64, 64], 'algo.lr': 1e-3, 'algo.use_lr_scheduler': True}\n", + "\n", + "policy = Policy(config, env_spec, None)\n", + "\n", + "agent = Agent(config, env_spec, None)\n", + "agent\n", + "\n", + "runner = SegmentRunner(config, agent, env)\n", + "\n", + "D = runner(50)" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Segment: \n", + "\tTransition: (s=[-0.04002427 0.00464987 -0.01704236 -0.03673052], a=1, r=1.0, s_next=[-0.03993127 0.20001201 -0.01777697 -0.33474139], done=False, info={'V': tensor([0.0065], grad_fn=), 'action_logprob': tensor(-0.6929, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.03993127 0.20001201 -0.01777697 -0.33474139], a=1, r=1.0, s_next=[-0.03593103 0.39538239 -0.0244718 -0.63297681], done=False, info={'V': tensor([0.1759], grad_fn=), 'action_logprob': tensor(-0.6924, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.03593103 0.39538239 -0.0244718 -0.63297681], a=0, r=1.0, s_next=[-0.02802339 0.20061023 -0.03713133 -0.34810004], done=False, info={'V': tensor([0.3214], grad_fn=), 'action_logprob': tensor(-0.6943, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.02802339 0.20061023 -0.03713133 -0.34810004], a=0, r=1.0, s_next=[-0.02401118 0.00603554 -0.04409334 -0.06735328], done=False, info={'V': tensor([0.1771], grad_fn=), 'action_logprob': tensor(-0.6940, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.02401118 0.00603554 -0.04409334 -0.06735328], a=1, r=1.0, s_next=[-0.02389047 0.20176102 -0.0454404 -0.37361538], done=False, info={'V': tensor([0.0110], grad_fn=), 'action_logprob': tensor(-0.6928, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + 
"\tTransition: (s=[-0.02389047 0.20176102 -0.0454404 -0.37361538], a=1, r=1.0, s_next=[-0.01985525 0.39749801 -0.05291271 -0.68027269], done=False, info={'V': tensor([0.1801], grad_fn=), 'action_logprob': tensor(-0.6923, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.01985525 0.39749801 -0.05291271 -0.68027269], a=1, r=1.0, s_next=[-0.01190529 0.59331344 -0.06651816 -0.98913383], done=False, info={'V': tensor([0.3225], grad_fn=), 'action_logprob': tensor(-0.6919, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.01190529 0.59331344 -0.06651816 -0.98913383], a=0, r=1.0, s_next=[-3.90204852e-05 3.99141970e-01 -8.63008390e-02 -7.18062662e-01], done=False, info={'V': tensor([0.4298], grad_fn=), 'action_logprob': tensor(-0.6946, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-3.90204852e-05 3.99141970e-01 -8.63008390e-02 -7.18062662e-01], a=0, r=1.0, s_next=[ 0.00794382 0.20531357 -0.10066209 -0.45374306], done=False, info={'V': tensor([0.3221], grad_fn=), 'action_logprob': tensor(-0.6945, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.00794382 0.20531357 -0.10066209 -0.45374306], a=1, r=1.0, s_next=[ 0.01205009 0.40170413 -0.10973695 -0.77638236], done=False, info={'V': tensor([0.1874], grad_fn=), 'action_logprob': tensor(-0.6920, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.01205009 0.40170413 -0.10973695 -0.77638236], a=0, r=1.0, s_next=[ 0.02008417 0.20824874 -0.1252646 -0.52014269], done=False, info={'V': tensor([0.3224], grad_fn=), 'action_logprob': tensor(-0.6947, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.02008417 0.20824874 -0.1252646 -0.52014269], a=0, r=1.0, s_next=[ 0.02424915 0.01509216 -0.13566745 -0.26940956], done=False, info={'V': tensor([0.1929], grad_fn=), 'action_logprob': tensor(-0.6945, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.02424915 0.01509216 -0.13566745 -0.26940956], a=0, r=1.0, s_next=[ 0.02455099 -0.17785921 -0.14105565 -0.02240595], done=False, info={'V': tensor([0.0410], grad_fn=), 'action_logprob': tensor(-0.6941, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.02455099 -0.17785921 -0.14105565 -0.02240595], a=0, r=1.0, s_next=[ 0.02099381 -0.37070612 -0.14150376 0.22265843], done=False, info={'V': tensor([-0.1224], grad_fn=), 'action_logprob': tensor(-0.6937, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.02099381 -0.37070612 -0.14150376 0.22265843], a=1, r=1.0, s_next=[ 0.01357968 -0.17387506 -0.1370506 -0.11110058], done=False, info={'V': tensor([-0.2809], grad_fn=), 'action_logprob': tensor(-0.6930, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.01357968 -0.17387506 -0.1370506 -0.11110058], a=1, r=1.0, s_next=[ 0.01010218 0.02291746 -0.13927261 -0.44368831], done=False, info={'V': tensor([-0.1078], grad_fn=), 'action_logprob': tensor(-0.6925, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.01010218 0.02291746 -0.13927261 -0.44368831], a=0, r=1.0, s_next=[ 0.01056053 -0.16998749 -0.14814637 -0.19794657], done=False, info={'V': tensor([0.0650], grad_fn=), 'action_logprob': tensor(-0.6944, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.01056053 -0.16998749 -0.14814637 -0.19794657], a=0, r=1.0, s_next=[ 0.00716078 -0.36271415 -0.15210531 0.04458096], done=False, info={'V': tensor([-0.0938], grad_fn=), 
'action_logprob': tensor(-0.6940, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.00716078 -0.36271415 -0.15210531 0.04458096], a=1, r=1.0, s_next=[-9.35003999e-05 -1.65775404e-01 -1.51213686e-01 -2.91963188e-01], done=False, info={'V': tensor([-0.2550], grad_fn=), 'action_logprob': tensor(-0.6927, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-9.35003999e-05 -1.65775404e-01 -1.51213686e-01 -2.91963188e-01], a=1, r=1.0, s_next=[-0.00340901 0.03114278 -0.15705295 -0.6282575 ], done=False, info={'V': tensor([-0.0788], grad_fn=), 'action_logprob': tensor(-0.6921, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.00340901 0.03114278 -0.15705295 -0.6282575 ], a=0, r=1.0, s_next=[-0.00278615 -0.16147909 -0.1696181 -0.38886472], done=False, info={'V': tensor([0.0852], grad_fn=), 'action_logprob': tensor(-0.6948, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.00278615 -0.16147909 -0.1696181 -0.38886472], a=0, r=1.0, s_next=[-0.00601573 -0.35383869 -0.17739539 -0.15409501], done=False, info={'V': tensor([-0.0643], grad_fn=), 'action_logprob': tensor(-0.6944, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.00601573 -0.35383869 -0.17739539 -0.15409501], a=1, r=1.0, s_next=[-0.01309251 -0.156679 -0.18047729 -0.49707454], done=False, info={'V': tensor([-0.2229], grad_fn=), 'action_logprob': tensor(-0.6923, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.01309251 -0.156679 -0.18047729 -0.49707454], a=0, r=1.0, s_next=[-0.01622609 -0.34885867 -0.19041878 -0.26625981], done=False, info={'V': tensor([-0.0490], grad_fn=), 'action_logprob': tensor(-0.6947, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.01622609 -0.34885867 -0.19041878 -0.26625981], a=1, r=1.0, s_next=[-0.02320326 -0.15160152 -0.19574398 -0.61244284], done=False, info={'V': tensor([-0.2043], grad_fn=), 'action_logprob': tensor(-0.6920, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.02320326 -0.15160152 -0.19574398 -0.61244284], a=1, r=1.0, s_next=[-0.02623529 0.0456394 -0.20799284 -0.95983615], done=False, info={'V': tensor([-0.0344], grad_fn=), 'action_logprob': tensor(-0.6914, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.02623529 0.0456394 -0.20799284 -0.95983615], a=0, r=1.0, s_next=[-0.0253225 -0.14617206 -0.22718956 -0.73902998], done=True, info={'V': tensor([0.1061], grad_fn=), 'action_logprob': tensor(-0.6953, grad_fn=), 'entropy': tensor(0.6931, grad_fn=), 'init_observation': array([-0.04892357, 0.02011271, 0.02775732, -0.04547827])})\n", + "\tTransition: (s=[-0.04892357 0.02011271 0.02775732 -0.04547827], a=1, r=1.0, s_next=[-0.04852131 0.21482587 0.02684775 -0.3292759 ], done=False, info={'V': tensor([0.0196], grad_fn=), 'action_logprob': tensor(-0.6931, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.04852131 0.21482587 0.02684775 -0.3292759 ], a=1, r=1.0, s_next=[-0.0442248 0.40955554 0.02026223 -0.61337285], done=False, info={'V': tensor([0.1866], grad_fn=), 'action_logprob': tensor(-0.6926, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.0442248 0.40955554 0.02026223 -0.61337285], a=1, r=1.0, s_next=[-0.03603368 0.60438856 0.00799478 -0.89960571], done=False, info={'V': tensor([0.3310], grad_fn=), 'action_logprob': tensor(-0.6922, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: 
(s=[-0.03603368 0.60438856 0.00799478 -0.89960571], a=1, r=1.0, s_next=[-0.02394591 0.79940126 -0.00999734 -1.18976497], done=False, info={'V': tensor([0.4439], grad_fn=), 'action_logprob': tensor(-0.6920, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.02394591 0.79940126 -0.00999734 -1.18976497], a=0, r=1.0, s_next=[-0.00795789 0.6044103 -0.03379264 -0.90023231], done=False, info={'V': tensor([0.5263], grad_fn=), 'action_logprob': tensor(-0.6944, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.00795789 0.6044103 -0.03379264 -0.90023231], a=0, r=1.0, s_next=[ 0.00413032 0.40976216 -0.05179728 -0.61835993], done=False, info={'V': tensor([0.4412], grad_fn=), 'action_logprob': tensor(-0.6944, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.00413032 0.40976216 -0.05179728 -0.61835993], a=0, r=1.0, s_next=[ 0.01232556 0.21540054 -0.06416448 -0.3424301 ], done=False, info={'V': tensor([0.3279], grad_fn=), 'action_logprob': tensor(-0.6943, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.01232556 0.21540054 -0.06416448 -0.3424301 ], a=0, r=1.0, s_next=[ 0.01663357 0.02124735 -0.07101308 -0.07065104], done=False, info={'V': tensor([0.1857], grad_fn=), 'action_logprob': tensor(-0.6940, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.01663357 0.02124735 -0.07101308 -0.07065104], a=0, r=1.0, s_next=[ 0.01705852 -0.17278846 -0.07242611 0.19880881], done=False, info={'V': tensor([0.0220], grad_fn=), 'action_logprob': tensor(-0.6936, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.01705852 -0.17278846 -0.07242611 0.19880881], a=0, r=1.0, s_next=[ 0.01360275 -0.3668038 -0.06844993 0.46779419], done=False, info={'V': tensor([-0.1458], grad_fn=), 'action_logprob': tensor(-0.6931, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.01360275 -0.3668038 -0.06844993 0.46779419], a=0, r=1.0, s_next=[ 0.00626667 -0.5608953 -0.05909405 0.73814111], done=False, info={'V': tensor([-0.2983], grad_fn=), 'action_logprob': tensor(-0.6926, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[ 0.00626667 -0.5608953 -0.05909405 0.73814111], a=0, r=1.0, s_next=[-0.00495123 -0.75515355 -0.04433122 1.0116563 ], done=False, info={'V': tensor([-0.4230], grad_fn=), 'action_logprob': tensor(-0.6923, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.00495123 -0.75515355 -0.04433122 1.0116563 ], a=1, r=1.0, s_next=[-0.0200543 -0.55946902 -0.0240981 0.70538879], done=False, info={'V': tensor([-0.5179], grad_fn=), 'action_logprob': tensor(-0.6941, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.0200543 -0.55946902 -0.0240981 0.70538879], a=0, r=1.0, s_next=[-0.03124368 -0.75424893 -0.00999032 0.99038966], done=False, info={'V': tensor([-0.4206], grad_fn=), 'action_logprob': tensor(-0.6923, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.03124368 -0.75424893 -0.00999032 0.99038966], a=1, r=1.0, s_next=[-0.04632866 -0.5589947 0.00981747 0.69458582], done=False, info={'V': tensor([-0.5165], grad_fn=), 'action_logprob': tensor(-0.6942, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.04632866 -0.5589947 0.00981747 0.69458582], a=1, r=1.0, s_next=[-0.05750856 -0.36401029 0.02370919 0.40500959], done=False, info={'V': tensor([-0.4183], grad_fn=), 'action_logprob': tensor(-0.6941, grad_fn=), 'entropy': 
tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.05750856 -0.36401029 0.02370919 0.40500959], a=1, r=1.0, s_next=[-0.06478876 -0.16923245 0.03180938 0.1198948 ], done=False, info={'V': tensor([-0.2890], grad_fn=), 'action_logprob': tensor(-0.6939, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.06478876 -0.16923245 0.03180938 0.1198948 ], a=1, r=1.0, s_next=[-0.06817341 0.02541965 0.03420728 -0.16258527], done=False, info={'V': tensor([-0.1320], grad_fn=), 'action_logprob': tensor(-0.6934, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.06817341 0.02541965 0.03420728 -0.16258527], a=0, r=1.0, s_next=[-0.06766502 -0.17017487 0.03095557 0.14068986], done=False, info={'V': tensor([0.0385], grad_fn=), 'action_logprob': tensor(-0.6934, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.06766502 -0.17017487 0.03095557 0.14068986], a=0, r=1.0, s_next=[-0.07106852 -0.36572619 0.03376937 0.44297578], done=False, info={'V': tensor([-0.1353], grad_fn=), 'action_logprob': tensor(-0.6928, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.07106852 -0.36572619 0.03376937 0.44297578], a=1, r=1.0, s_next=[-0.07838304 -0.17109794 0.04262888 0.16112616], done=False, info={'V': tensor([-0.2922], grad_fn=), 'action_logprob': tensor(-0.6939, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.07838304 -0.17109794 0.04262888 0.16112616], a=0, r=1.0, s_next=[-0.081805 -0.36680342 0.04585141 0.4669468 ], done=False, info={'V': tensor([-0.1382], grad_fn=), 'action_logprob': tensor(-0.6928, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})\n", + "\tTransition: (s=[-0.081805 -0.36680342 0.04585141 0.4669468 ], a=0, r=1.0, s_next=[-0.08914107 -0.56254219 0.05519034 0.77372196], done=False, info={'V': tensor([-0.2937], grad_fn=), 'action_logprob': tensor(-0.6923, grad_fn=), 'entropy': tensor(0.6931, grad_fn=)})" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "segment = D[0]\n", + "segment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 92, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Trajectory: \n", + "\tTransition: (s=1.0, a=10, r=0.1, s_next=2.0, done=False, info={})\n", + "\tTransition: (s=2.0, a=20, r=0.2, s_next=3.0, done=False, info={})\n", + "\tTransition: (s=3.0, a=30, r=0.3, s_next=4.0, done=True, info={})" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([-0.0253225 , -0.14617206, -0.22718956, -0.73902998]),\n", + " array([-0.08914107, -0.56254219, 0.05519034, 0.77372196])]" + ] + }, + "execution_count": 54, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "final_state_from_segment(segment)" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": 
"code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "test_bootstrapped_returns_from_segment()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.allclose([0.6, 0.5, 0.3, 1.1, 0.6], [0.6, 0.5, 0.3, 1.1, 0.6])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "a = torch.randn(3, 4).to('cuda')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "code", "execution_count": 219, diff --git a/examples/policy_gradient/a2c/main.py b/examples/policy_gradient/a2c/main.py index 3dd33c96..1171bfed 100644 --- a/examples/policy_gradient/a2c/main.py +++ b/examples/policy_gradient/a2c/main.py @@ -6,5 +6,4 @@ run_experiment(worker_class=ExperimentWorker, master_class=ExperimentMaster, - max_num_worker=None, - daemonic_worker=None) + num_worker=100) diff --git a/examples/policy_gradient/a2c/model.py b/examples/policy_gradient/a2c/model.py new file mode 100644 index 00000000..c75c8d89 --- /dev/null +++ b/examples/policy_gradient/a2c/model.py @@ -0,0 +1,231 @@ +import numpy as np + +import torch +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn.utils import clip_grad_norm_ + +from lagom.networks import BaseNetwork +from lagom.networks import make_fc +from lagom.networks import ortho_init +from lagom.networks import linear_lr_scheduler + +from lagom.policies import BasePolicy +from lagom.policies import CategoricalHead +from lagom.policies import DiagGaussianHead +from lagom.policies import constraint_action + +from lagom.value_functions import StateValueHead + +from lagom.transform import Standardize + +from lagom.history.metrics import final_state_from_segment +from lagom.history.metrics import terminal_state_from_segment +from lagom.history.metrics import bootstrapped_returns_from_segment + +from lagom.agents import BaseAgent + + +class MLP(BaseNetwork): + def make_params(self, config): + self.feature_layers = make_fc(self.env_spec.observation_space.flat_dim, config['network.hidden_sizes']) + + def init_params(self, config): + for layer in self.feature_layers: + ortho_init(layer, nonlinearity='tanh', constant_bias=0.0) + + def reset(self, config, **kwargs): + pass + + def forward(self, x): + for layer in self.feature_layers: + x = torch.tanh(layer(x)) + + return x + + +class Policy(BasePolicy): + def make_networks(self, config): + self.feature_network = MLP(config, self.device, env_spec=self.env_spec) + feature_dim = config['network.hidden_sizes'][-1] + + if self.env_spec.control_type == 'Discrete': + self.action_head = CategoricalHead(config, self.device, feature_dim, self.env_spec) + elif self.env_spec.control_type == 'Continuous': + self.action_head = DiagGaussianHead(config, + self.device, + feature_dim, + self.env_spec, + min_std=config['agent.min_std'], + std_style=config['agent.std_style'], + constant_std=config['agent.constant_std'], + 
std_state_dependent=config['agent.std_state_dependent'], + init_std=config['agent.init_std']) + self.V_head = StateValueHead(config, self.device, feature_dim) + + @property + def recurrent(self): + return False + + def reset(self, config, **kwargs): + pass + + def __call__(self, x, out_keys=['action', 'V'], info={}, **kwargs): + out = {} + + features = self.feature_network(x) + action_dist = self.action_head(features) + + action = action_dist.sample().detach() + out['action'] = action + + V = self.V_head(features) + out['V'] = V + + if 'action_logprob' in out_keys: + out['action_logprob'] = action_dist.log_prob(action) + if 'entropy' in out_keys: + out['entropy'] = action_dist.entropy() + if 'perplexity' in out_keys: + out['perplexity'] = action_dist.perplexity() + + return out + +class Agent(BaseAgent): + r"""`Advantage Actor-Critic`_ (A2C). + + The main difference in A2C is the use of bootstrapping to estimate the advantage function and to train the value function. + + .. _Advantage Actor-Critic: + https://arxiv.org/abs/1602.01783 + + Like `OpenAI baselines`_, we use fixed-length segments of experience to compute returns and advantages. + + .. _OpenAI baselines: + https://blog.openai.com/baselines-acktr-a2c/ + + .. note:: + + Use :class:`SegmentRunner` to collect data, not :class:`TrajectoryRunner` + + """ + def make_modules(self, config): + self.policy = Policy(config, self.env_spec, self.device) + + def prepare(self, config, **kwargs): + self.total_T = 0 + self.optimizer = optim.Adam(self.policy.parameters(), lr=config['algo.lr']) + if config['algo.use_lr_scheduler']: + if 'train.iter' in config: + self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.iter'], 'iteration-based') + elif 'train.timestep' in config: + self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.timestep']+1, 'timestep-based') + else: + self.lr_scheduler = None + + def reset(self, config, **kwargs): + pass + + def choose_action(self, obs, info={}): + obs = torch.from_numpy(np.asarray(obs)).float().to(self.device) + + out = self.policy(obs, out_keys=['action', 'action_logprob', 'V', 'entropy'], info=info) + + # sanity check for NaN + if torch.any(torch.isnan(out['action'])): + while True: + print('NaN !') + if self.env_spec.control_type == 'Continuous': + out['action'] = constraint_action(self.env_spec, out['action']) + + return out + + def learn(self, D, info={}): + batch_policy_loss = [] + batch_entropy_loss = [] + batch_value_loss = [] + batch_total_loss = [] + + for segment in D: + logprobs = segment.all_info('action_logprob') + entropies = segment.all_info('entropy') + + final_states = final_state_from_segment(segment) + final_states = torch.tensor(final_states).float().to(self.device) + all_V_last = self.policy(final_states)['V'].cpu().detach().numpy() + Qs = bootstrapped_returns_from_segment(segment, all_V_last, self.config['algo.gamma']) + # Standardize: encourage/discourage half of performed actions + if self.config['agent.standardize_Q']: + Qs = Standardize()(Qs, -1).tolist() + + Vs = segment.all_info('V') + terminal_states = terminal_state_from_segment(segment) + if len(terminal_states) > 0: + terminal_states = torch.tensor(terminal_states).float().to(self.device) + all_V_terminal = self.policy(terminal_states)['V'] + else: + all_V_terminal = [] + + As = [Q - V.item() for Q, V in zip(Qs, Vs)] + if self.config['agent.standardize_adv']: + As = Standardize()(As, -1).tolist() + + policy_loss = [] + entropy_loss = [] + value_loss = [] + for logprob, 
entropy, A, Q, V in zip(logprobs, entropies, As, Qs, Vs): + policy_loss.append(-logprob*A) + entropy_loss.append(-entropy) + value_loss.append(F.mse_loss(V, torch.tensor(Q).view_as(V).to(V.device))) + for V_terminal in all_V_terminal: + value_loss.append(F.mse_loss(V_terminal, torch.tensor(0.0).view_as(V).to(V.device))) + + policy_loss = torch.stack(policy_loss).mean() + entropy_loss = torch.stack(entropy_loss).mean() + value_loss = torch.stack(value_loss).mean() + + entropy_coef = self.config['agent.entropy_coef'] + value_coef = self.config['agent.value_coef'] + total_loss = policy_loss + value_coef*value_loss + entropy_coef*entropy_loss + + batch_policy_loss.append(policy_loss) + batch_entropy_loss.append(entropy_loss) + batch_value_loss.append(value_loss) + batch_total_loss.append(total_loss) + + policy_loss = torch.stack(batch_policy_loss).mean() + entropy_loss = torch.stack(batch_entropy_loss).mean() + value_loss = torch.stack(batch_value_loss).mean() + loss = torch.stack(batch_total_loss).mean() + + self.optimizer.zero_grad() + loss.backward() + + if self.config['agent.max_grad_norm'] is not None: + clip_grad_norm_(self.parameters(), self.config['agent.max_grad_norm']) + + if self.lr_scheduler is not None: + if self.lr_scheduler.mode == 'iteration-based': + self.lr_scheduler.step() + elif self.lr_scheduler.mode == 'timestep-based': + self.lr_scheduler.step(self.total_T) + + self.optimizer.step() + + self.total_T += sum([segment.T for segment in D]) + + out = {} + out['loss'] = loss.item() + out['policy_loss'] = policy_loss.item() + out['entropy_loss'] = entropy_loss.item() + out['value_loss'] = value_loss.item() + if self.lr_scheduler is not None: + out['current_lr'] = self.lr_scheduler.get_lr() + + return out + + @property + def recurrent(self): + pass diff --git a/examples/policy_gradient/a2c/policy.py b/examples/policy_gradient/a2c/policy.py deleted file mode 100644 index 43982b5f..00000000 --- a/examples/policy_gradient/a2c/policy.py +++ /dev/null @@ -1,61 +0,0 @@ -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from lagom.core.networks import BaseNetwork -from lagom.core.networks import BaseRNN -from lagom.core.networks import LayerNormLSTMCell -from lagom.core.networks import make_fc -from lagom.core.networks import ortho_init - - -class Network(BaseNetwork): - def make_params(self, config): - self.layers = make_fc(input_dim=self.env_spec.observation_space.flat_dim, - hidden_sizes=config['network.hidden_sizes']) - self.last_feature_dim = config['network.hidden_sizes'][-1] - - def init_params(self, config): - for layer in self.layers: - ortho_init(layer, nonlinearity='tanh', constant_bias=0.0) - - def forward(self, x): - for layer in self.layers: - x = torch.tanh(layer(x)) - - return x - - -class LSTM(BaseRNN): - def make_params(self, config): - # nn.LSTMCell - self.rnn = LayerNormLSTMCell(input_size=self.env_spec.observation_space.flat_dim, - hidden_size=config['network.hidden_sizes'][0]) # TODO: support multi-layer - self.last_feature_dim = config['network.hidden_sizes'][-1] - - def init_params(self, config): - ortho_init(self.rnn, nonlinearity=None, weight_scale=1.0, constant_bias=0.0) - - def init_hidden_states(self, config, batch_size, **kwargs): - h = torch.zeros(batch_size, config['network.hidden_sizes'][0]) - h = h.to(self.device) - c = torch.zeros_like(h) - - return [h, c] - - def rnn_forward(self, x, hidden_states, mask=None, **kwargs): - if mask is not None: - mask = mask.to(self.device) - - h, c = hidden_states - h = 
h*mask - c = c*mask - hidden_states = [h, c] - - h, c = self.rnn(x, hidden_states) - - out = {'output': h, 'hidden_states': [h, c]} - - return out diff --git a/examples/policy_gradient/a2c_agent.py b/examples/policy_gradient/a2c_agent.py deleted file mode 100644 index 31ca7529..00000000 --- a/examples/policy_gradient/a2c_agent.py +++ /dev/null @@ -1,152 +0,0 @@ -import numpy as np - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.nn.utils import clip_grad_norm_ - -from .base_agent import BaseAgent - -from lagom.core.transform import Standardize - - -class A2CAgent(BaseAgent): - r"""`Advantage Actor-Critic`_ (A2C) with option to use Generalized Advantage Estimate (GAE) - - The main difference of A2C is to use bootstrapping for estimating the advantage function and training value function. - - .. _Advantage Actor-Critic: - https://arxiv.org/abs/1602.01783 - - Like `OpenAI baselines` we use fixed-length segments of experiment to compute returns and advantages. - - .. _OpenAI baselines: - https://blog.openai.com/baselines-acktr-a2c/ - - .. note:: - - Use :class:`SegmentRunner` to collect data, not :class:`TrajectoryRunner` - - """ - def __init__(self, config, device, policy, optimizer, **kwargs): - self.policy = policy - self.optimizer = optimizer - - super().__init__(config, device, **kwargs) - - # accumulated trained timesteps - self.total_T = 0 - - def choose_action(self, obs, info={}): - if not torch.is_tensor(obs): - obs = np.asarray(obs) - assert obs.ndim >= 2, f'expected at least 2-dim for batched data, got {obs.ndim}' - obs = torch.from_numpy(obs).float().to(self.device) - - if self.policy.recurrent and self.info['reset_rnn_states']: - self.policy.reset_rnn_states(batch_size=obs.size(0)) - self.info['reset_rnn_states'] = False # Done, turn off - - out_policy = self.policy(obs, - out_keys=['action', 'action_logprob', 'state_value', - 'entropy', 'perplexity'], - info=info) - - return out_policy - - def learn(self, D, info={}): - batch_policy_loss = [] - batch_entropy_loss = [] - batch_value_loss = [] - batch_total_loss = [] - - for segment in D: - logprobs = segment.all_info('action_logprob') - entropies = segment.all_info('entropy') - Qs = segment.all_bootstrapped_discounted_returns - - # Standardize: encourage/discourage half of performed actions - if self.config['agent.standardize_Q']: - Qs = Standardize()(Qs).tolist() - - # State values - Vs, finals = segment.all_V - final_Vs, final_dones = zip(*finals) - assert len(Vs) == len(segment.transitions) - - # Advantage estimates - As = [Q - V.item() for Q, V in zip(Qs, Vs)] - if self.config['agent.standardize_adv']: - As = Standardize()(As).tolist() - - # Estimate policy gradient for all time steps and record all losses - policy_loss = [] - entropy_loss = [] - value_loss = [] - for logprob, entropy, A, Q, V in zip(logprobs, entropies, As, Qs, Vs): - policy_loss.append(-logprob*A) - entropy_loss.append(-entropy) - value_loss.append(F.mse_loss(V, torch.tensor(Q).view_as(V).to(V.device))) - for final_V, final_done in zip(final_Vs, final_dones): # learn terminal state value as zero - if final_done: - value_loss.append(F.mse_loss(final_V, torch.tensor(0.0).view_as(V).to(V.device))) - - # Average losses over all time steps - policy_loss = torch.stack(policy_loss).mean() - entropy_loss = torch.stack(entropy_loss).mean() - value_loss = torch.stack(value_loss).mean() - - # Calculate total loss - entropy_coef = self.config['agent.entropy_coef'] - value_coef = self.config['agent.value_coef'] - total_loss = 
policy_loss + value_coef*value_loss + entropy_coef*entropy_loss - - # Record all losses - batch_policy_loss.append(policy_loss) - batch_entropy_loss.append(entropy_loss) - batch_value_loss.append(value_loss) - batch_total_loss.append(total_loss) - - # Average loss over list of Segment - policy_loss = torch.stack(batch_policy_loss).mean() - entropy_loss = torch.stack(batch_entropy_loss).mean() - value_loss = torch.stack(batch_value_loss).mean() - loss = torch.stack(batch_total_loss).mean() - - # Train with estimated policy gradient - self.optimizer.zero_grad() - loss.backward() - - if self.config['agent.max_grad_norm'] is not None: - clip_grad_norm_(parameters=self.policy.network.parameters(), - max_norm=self.config['agent.max_grad_norm'], - norm_type=2) - - if hasattr(self, 'lr_scheduler'): - if 'train.iter' in self.config: # iteration-based - self.lr_scheduler.step() - elif 'train.timestep' in self.config: # timestep-based - self.lr_scheduler.step(self.total_T) - else: - raise KeyError('expected `train.iter` or `train.timestep` in config, but got none of them') - - self.optimizer.step() - - # Accumulate trained timesteps - self.total_T += sum([segment.T for segment in D]) - - out = {} - out['loss'] = loss.item() - out['policy_loss'] = policy_loss.item() - out['entropy_loss'] = entropy_loss.item() - out['value_loss'] = value_loss.item() - if hasattr(self, 'lr_scheduler'): - out['current_lr'] = self.lr_scheduler.get_lr() - - return out - - def save(self, f): - self.policy.network.save(f) - - def load(self, f): - self.policy.network.load(f) diff --git a/examples/policy_gradient/reinforce/algo.py b/examples/policy_gradient/reinforce/algo.py index 96a285c7..15ce5c74 100644 --- a/examples/policy_gradient/reinforce/algo.py +++ b/examples/policy_gradient/reinforce/algo.py @@ -58,7 +58,6 @@ def __call__(self, config, seed, device): eps=env.eps, constant_obs_mean=env.obs_runningavg.mu, constant_obs_std=env.obs_runningavg.sigma) - env_spec = EnvSpec(env) agent = Agent(config, env_spec, device) diff --git a/examples/policy_gradient/reinforce/experiment.py b/examples/policy_gradient/reinforce/experiment.py index ac6a66b3..e3f1ef89 100644 --- a/examples/policy_gradient/reinforce/experiment.py +++ b/examples/policy_gradient/reinforce/experiment.py @@ -21,17 +21,17 @@ def make_configs(self): configurator.fixed('cuda', True) # whether to use GPU - configurator.fixed('env.id', 'Pendulum-v0') + configurator.fixed('env.id', 'HalfCheetah-v2') configurator.fixed('env.standardize', True) # whether to use VecStandardize configurator.fixed('network.recurrent', False) - configurator.fixed('network.hidden_sizes', [32]) # TODO: [64, 64] + configurator.fixed('network.hidden_sizes', [64, 64]) # TODO: [64, 64] configurator.fixed('algo.lr', 1e-3) configurator.fixed('algo.use_lr_scheduler', True) configurator.fixed('algo.gamma', 0.99) - configurator.fixed('agent.standardize_Q', False) # whether to standardize discounted returns + configurator.fixed('agent.standardize_Q', True) # whether to standardize discounted returns configurator.fixed('agent.max_grad_norm', 0.5) # grad clipping, set None to turn off configurator.fixed('agent.entropy_coef', 0.01) # only for continuous control @@ -55,7 +55,7 @@ def make_configs(self): return list_config def make_seeds(self): - list_seed = [1]#[209652396, 398764591, 924231285, 1478610112, 441365315] + list_seed = [209652396, 398764591, 924231285, 1478610112, 441365315] return list_seed diff --git a/examples/policy_gradient/vpg/algo.py b/examples/policy_gradient/vpg/algo.py index 
ad744c49..6938b42e 100644 --- a/examples/policy_gradient/vpg/algo.py +++ b/examples/policy_gradient/vpg/algo.py @@ -3,54 +3,41 @@ from itertools import count import numpy as np - import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from lagom import set_global_seeds +from lagom import Logger +from lagom.utils import pickle_dump +from lagom.utils import set_global_seeds + from lagom import BaseAlgorithm -from lagom import pickle_dump from lagom.envs import make_gym_env from lagom.envs import make_vec_env from lagom.envs import EnvSpec from lagom.envs.vec_env import SerialVecEnv -from lagom.envs.vec_env import ParallelVecEnv from lagom.envs.vec_env import VecStandardize -from lagom.core.policies import CategoricalPolicy -from lagom.core.policies import GaussianPolicy - from lagom.runner import TrajectoryRunner -from lagom.agents import VPGAgent - +from model import Agent from engine import Engine -from policy import Network -from policy import LSTM class Algorithm(BaseAlgorithm): - def __call__(self, config, seed, device_str): + def __call__(self, config, seed, device): set_global_seeds(seed) - device = torch.device(device_str) logdir = Path(config['log.dir']) / str(config['ID']) / str(seed) - - # Environment related + env = make_vec_env(vec_env_class=SerialVecEnv, make_env=make_gym_env, env_id=config['env.id'], num_env=config['train.N'], # batched environment - init_seed=seed, - rolling=False) + init_seed=seed) eval_env = make_vec_env(vec_env_class=SerialVecEnv, make_env=make_gym_env, env_id=config['env.id'], num_env=config['eval.N'], - init_seed=seed, - rolling=False) + init_seed=seed) if config['env.standardize']: # running averages of observation and reward env = VecStandardize(venv=env, use_obs=True, @@ -59,74 +46,24 @@ def __call__(self, config, seed, device_str): clip_reward=10., gamma=0.99, eps=1e-8) - eval_env = VecStandardize(venv=eval_env, # remember to synchronize running averages during evaluation !!! 
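# A minimal sketch (made-up rewards/values, not the lagom implementation) contrasting the two
# return estimates used in these examples: VPG collects complete trajectories and can use plain
# discounted returns, while A2C collects fixed-length segments and bootstraps the tail with
# V(s_last), as bootstrapped_returns_from_segment does.
def discounted_returns(rewards, gamma):
    out, G = [], 0.0
    for r in reversed(rewards):
        G = r + gamma * G
        out.append(G)
    return out[::-1]

def bootstrapped_returns(rewards, V_last, gamma, done):
    # if the segment ended mid-episode, seed the backward recursion with V of the last state
    G = 0.0 if done else V_last
    out = []
    for r in reversed(rewards):
        G = r + gamma * G
        out.append(G)
    return out[::-1]

print(discounted_returns([1.0, 1.0, 1.0], 0.99))                                  # roughly [2.97, 1.99, 1.0]
print(bootstrapped_returns([1.0, 1.0, 1.0], V_last=0.5, gamma=0.99, done=False))  # roughly [3.46, 2.48, 1.50]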
+ eval_env = VecStandardize(venv=eval_env, use_obs=True, use_reward=False, # do not process rewards, no training clip_obs=env.clip_obs, clip_reward=env.clip_reward, gamma=env.gamma, eps=env.eps, - constant_obs_mean=env.obs_runningavg.mu, # use current running average as constant + constant_obs_mean=env.obs_runningavg.mu, constant_obs_std=env.obs_runningavg.sigma) env_spec = EnvSpec(env) - # Network and policy - if config['network.recurrent']: - network = LSTM(config=config, device=device, env_spec=env_spec) - else: - network = Network(config=config, device=device, env_spec=env_spec) - if env_spec.control_type == 'Discrete': - policy = CategoricalPolicy(config=config, - network=network, - env_spec=env_spec, - device=device, - learn_V=True) - elif env_spec.control_type == 'Continuous': - policy = GaussianPolicy(config=config, - network=network, - env_spec=env_spec, - device=device, - learn_V=True, - min_std=config['agent.min_std'], - std_style=config['agent.std_style'], - constant_std=config['agent.constant_std'], - std_state_dependent=config['agent.std_state_dependent'], - init_std=config['agent.init_std']) - - # Optimizer and learning rate scheduler - optimizer = optim.Adam(policy.network.parameters(), lr=config['algo.lr']) - if config['algo.use_lr_scheduler']: - if 'train.iter' in config: # iteration-based - max_epoch = config['train.iter'] - elif 'train.timestep' in config: # timestep-based - max_epoch = config['train.timestep'] + 1 # avoid zero lr in final iteration - lambda_f = lambda epoch: 1 - epoch/max_epoch - lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_f) - - # Agent - kwargs = {'device': device} - if config['algo.use_lr_scheduler']: - kwargs['lr_scheduler'] = lr_scheduler - agent = VPGAgent(config=config, - policy=policy, - optimizer=optimizer, - **kwargs) + agent = Agent(config, env_spec, device) - # Runner - runner = TrajectoryRunner(agent=agent, - env=env, - gamma=config['algo.gamma']) - eval_runner = TrajectoryRunner(agent=agent, - env=eval_env, - gamma=1.0) + runner = TrajectoryRunner(config, agent, env) + eval_runner = TrajectoryRunner(config, agent, eval_env) - # Engine - engine = Engine(agent=agent, - runner=runner, - config=config, - eval_runner=eval_runner) + engine = Engine(agent, runner, config, eval_runner=eval_runner) - # Training and evaluation train_logs = [] eval_logs = [] for i in count(): @@ -135,9 +72,8 @@ def __call__(self, config, seed, device_str): elif 'train.timestep' in config and agent.total_T >= config['train.timestep']: # enough timesteps break - train_output = engine.train(n=i) + train_output = engine.train(i) - # Logging if i == 0 or (i+1) % config['log.record_interval'] == 0 or (i+1) % config['log.print_interval'] == 0: train_log = engine.log_train(train_output) @@ -149,7 +85,6 @@ def __call__(self, config, seed, device_str): train_logs.append(train_log) eval_logs.append(eval_log) - # Save all loggings pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl') pickle_dump(obj=eval_logs, f=logdir/'eval_logs', ext='.pkl') diff --git a/examples/policy_gradient/vpg/engine.py b/examples/policy_gradient/vpg/engine.py index c90f7bac..fb45f84d 100644 --- a/examples/policy_gradient/vpg/engine.py +++ b/examples/policy_gradient/vpg/engine.py @@ -3,7 +3,7 @@ import torch from lagom import Logger -from lagom import color_str +from lagom.utils import color_str from lagom.engine import BaseEngine @@ -15,12 +15,10 @@ class Engine(BaseEngine): def train(self, n): - self.agent.policy.network.train() # train mode + self.agent.train() - # 
Collect a list of Trajectory D = self.runner(T=self.config['train.T']) - # Train agent with collected data out_agent = self.agent.learn(D) train_output = {} @@ -31,36 +29,32 @@ def train(self, n): return train_output def log_train(self, train_output, **kwargs): - # Unpack D = train_output['D'] out_agent = train_output['out_agent'] n = train_output['n'] - # Loggings - logger = Logger(name='train_logger') - logger.log('train_iteration', n+1) # starts from 1 - if self.config['algo.use_lr_scheduler']: - logger.log('current_lr', out_agent['current_lr']) - - logger.log('loss', out_agent['loss']) - logger.log('policy_loss', out_agent['policy_loss']) - logger.log('policy_entropy', -out_agent['entropy_loss']) # entropy: negative entropy loss - logger.log('value_loss', out_agent['value_loss']) + logger = Logger() + logger('train_iteration', n+1) # starts from 1 + if 'current_lr' in out_agent: + logger('current_lr', out_agent['current_lr']) + logger('loss', out_agent['loss']) + logger('policy_loss', out_agent['policy_loss']) + logger('policy_entropy', -out_agent['entropy_loss']) + logger('value_loss', out_agent['value_loss']) batch_returns = [sum(trajectory.all_r) for trajectory in D] - batch_discounted_returns = [trajectory.all_discounted_returns[0] for trajectory in D] + batch_discounted_returns = [trajectory.all_discounted_returns(self.config['algo.gamma'])[0] for trajectory in D] num_timesteps = sum([trajectory.T for trajectory in D]) - logger.log('num_trajectories', len(D)) - logger.log('num_timesteps', num_timesteps) - logger.log('accumulated_trained_timesteps', self.agent.total_T) - logger.log('average_return', np.mean(batch_returns)) - logger.log('average_discounted_return', np.mean(batch_discounted_returns)) - logger.log('std_return', np.std(batch_returns)) - logger.log('min_return', np.min(batch_returns)) - logger.log('max_return', np.max(batch_returns)) + logger('num_trajectories', len(D)) + logger('num_timesteps', num_timesteps) + logger('accumulated_trained_timesteps', self.agent.total_T) + logger('average_return', np.mean(batch_returns)) + logger('average_discounted_return', np.mean(batch_discounted_returns)) + logger('std_return', np.std(batch_returns)) + logger('min_return', np.min(batch_returns)) + logger('max_return', np.max(batch_returns)) - # Dump loggings if n == 0 or (n+1) % self.config['log.print_interval'] == 0: print('-'*50) logger.dump(keys=None, index=None, indent=0) @@ -69,16 +63,15 @@ def log_train(self, train_output, **kwargs): return logger.logs def eval(self, n): - self.agent.policy.network.eval() # evaluation mode + self.agent.eval() # Synchronize running average of observations for evaluation if self.config['env.standardize']: self.eval_runner.env.constant_obs_mean = self.runner.env.obs_runningavg.mu self.eval_runner.env.constant_obs_std = self.runner.env.obs_runningavg.sigma - # Collect a list of Trajectory T = self.eval_runner.env.T - D = self.eval_runner(T=T) + D = self.eval_runner(T) eval_output = {} eval_output['D'] = D @@ -88,29 +81,26 @@ def eval(self, n): return eval_output def log_eval(self, eval_output, **kwargs): - # Unpack D = eval_output['D'] n = eval_output['n'] T = eval_output['T'] - # Loggings - logger = Logger(name='eval_logger') + logger = Logger() batch_returns = [sum(trajectory.all_r) for trajectory in D] batch_T = [trajectory.T for trajectory in D] - logger.log('evaluation_iteration', n+1) - logger.log('num_trajectories', len(D)) - logger.log('max_allowed_horizon', T) - logger.log('average_horizon', np.mean(batch_T)) - logger.log('num_timesteps', 
np.sum(batch_T)) - logger.log('accumulated_trained_timesteps', self.agent.total_T) - logger.log('average_return', np.mean(batch_returns)) - logger.log('std_return', np.std(batch_returns)) - logger.log('min_return', np.min(batch_returns)) - logger.log('max_return', np.max(batch_returns)) - - # Dump loggings + logger('evaluation_iteration', n+1) + logger('num_trajectories', len(D)) + logger('max_allowed_horizon', T) + logger('average_horizon', np.mean(batch_T)) + logger('num_timesteps', np.sum(batch_T)) + logger('accumulated_trained_timesteps', self.agent.total_T) + logger('average_return', np.mean(batch_returns)) + logger('std_return', np.std(batch_returns)) + logger('min_return', np.min(batch_returns)) + logger('max_return', np.max(batch_returns)) + if n == 0 or (n+1) % self.config['log.print_interval'] == 0: print(color_str('+'*50, 'yellow', 'bold')) logger.dump(keys=None, index=None, indent=0) diff --git a/examples/policy_gradient/vpg/experiment.py b/examples/policy_gradient/vpg/experiment.py index c9aa6cf6..7e4d20fb 100644 --- a/examples/policy_gradient/vpg/experiment.py +++ b/examples/policy_gradient/vpg/experiment.py @@ -6,8 +6,11 @@ class ExperimentWorker(BaseExperimentWorker): + def prepare(self): + pass + def make_algo(self): - algo = Algorithm(name='Vanilla Policy Gradient') + algo = Algorithm() return algo @@ -21,8 +24,8 @@ def make_configs(self): configurator.fixed('env.id', 'HalfCheetah-v2') configurator.fixed('env.standardize', True) # whether to use VecStandardize - configurator.fixed('network.recurrent', True) - configurator.fixed('network.hidden_sizes', [8]) # TODO: [64, 64] + configurator.fixed('network.recurrent', False) + configurator.fixed('network.hidden_sizes', [64, 64]) # TODO: [64, 64] configurator.fixed('algo.lr', 1e-3) configurator.fixed('algo.use_lr_scheduler', True) @@ -46,7 +49,7 @@ def make_configs(self): configurator.fixed('eval.N', 10) # number of episodes to evaluate, do not specify T for complete episode configurator.fixed('log.record_interval', 100) # interval to record the logging - configurator.fixed('log.print_interval', 500) # interval to print the logging to screen + configurator.fixed('log.print_interval', 100) # interval to print the logging to screen configurator.fixed('log.dir', 'logs') # logging directory list_config = configurator.make_configs() @@ -58,5 +61,5 @@ def make_seeds(self): return list_seed - def process_algo_result(self, config, seed, result): - assert result is None + def process_results(self, results): + assert all([result is None for result in results]) diff --git a/examples/policy_gradient/vpg/main.ipynb b/examples/policy_gradient/vpg/main.ipynb index 88605e81..821d5249 100644 --- a/examples/policy_gradient/vpg/main.ipynb +++ b/examples/policy_gradient/vpg/main.ipynb @@ -16,208 +16,41 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from lagom.value_functions import StateValueHead\n", + "\n", + "StateValueHead()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 4, "metadata": {}, "outputs": [ { - "ename": "ImportError", - "evalue": "libcudart.so.9.2: cannot open shared object file: No such file or directory", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)", - 
"\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m~/anaconda3/envs/RL/lib/python3.7/site-packages/torch/__init__.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 84\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 85\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 86\u001b[0m __all__ += [name for name in dir(_C)\n", - "\u001b[0;31mImportError\u001b[0m: libcudart.so.9.2: cannot open shared object file: No such file or directory" + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/zuo/anaconda3/envs/RL/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n", + " return f(*args, **kwds)\n", + "/home/zuo/Code/lagom/lagom/vis/__init__.py:10: UserWarning: ImageViewer failed to import due to pyglet. \n", + " warnings.warn('ImageViewer failed to import due to pyglet. ')\n", + "/home/zuo/Code/lagom/lagom/envs/vec_env/vec_env.py:12: UserWarning: ImageViewer failed to import due to pyglet. \n", + " warnings.warn('ImageViewer failed to import due to pyglet. 
')\n" ] } ], - "source": [ - "import numpy as np\n", - "\n", - "import torch\n", - "import torch.optim as optim\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "\n", - "from torch.nn.utils import clip_grad_norm_\n", - "\n", - "from lagom.networks import BaseNetwork\n", - "from lagom.networks import make_fc\n", - "from lagom.networks import ortho_init\n", - "from lagom.networks import linear_lr_scheduler\n", - "\n", - "from lagom.policies import BasePolicy\n", - "from lagom.policies import CategoricalHead\n", - "from lagom.policies import DiagGaussianHead\n", - "from lagom.policies import constraint_action\n", - "\n", - "from lagom.value\n", - "\n", - "from lagom.transform import Standardize\n", - "\n", - "from lagom.agents import BaseAgent\n", - "\n", - "\n", - "class MLP(BaseNetwork):\n", - " def make_params(self, config):\n", - " self.feature_layers = make_fc(self.env_spec.observation_space.flat_dim, config['network.hidden_sizes'])\n", - " \n", - " def init_params(self, config):\n", - " for layer in self.feature_layers:\n", - " ortho_init(layer, nonlinearity='tanh', constant_bias=0.0)\n", - " \n", - " def reset(self, config, **kwargs):\n", - " pass\n", - " \n", - " def forward(self, x):\n", - " for layer in self.feature_layers:\n", - " x = torch.tanh(layer(x))\n", - " \n", - " return x\n", - " \n", - " \n", - "class Policy(BasePolicy):\n", - " def make_networks(self, config):\n", - " self.feature_network = MLP(config, self.device, env_spec=self.env_spec)\n", - " feature_dim = config['network.hidden_sizes'][-1]\n", - " \n", - " if self.env_spec.control_type == 'Discrete':\n", - " self.action_head = CategoricalHead(config, self.device, feature_dim, self.env_spec)\n", - " elif self.env_spec.control_type == 'Continuous':\n", - " self.action_head = DiagGaussianHead(config, \n", - " self.device, \n", - " feature_dim, \n", - " self.env_spec, \n", - " min_std=config['agent.min_std'], \n", - " std_style=config['agent.std_style'], \n", - " constant_std=config['agent.constant_std'],\n", - " std_state_dependent=config['agent.std_state_dependent'],\n", - " init_std=config['agent.init_std'])\n", - " \n", - " @property\n", - " def recurrent(self):\n", - " return False\n", - " \n", - " def reset(self, config, **kwargs):\n", - " pass\n", - "\n", - " def __call__(self, x, out_keys=['action'], info={}, **kwargs):\n", - " out = {}\n", - " \n", - " features = self.feature_network(x)\n", - " action_dist = self.action_head(features)\n", - " \n", - " action = action_dist.sample().detach()################################\n", - " out['action'] = action\n", - " \n", - " if 'action_logprob' in out_keys:\n", - " out['action_logprob'] = action_dist.log_prob(action)\n", - " if 'entropy' in out_keys:\n", - " out['entropy'] = action_dist.entropy()\n", - " if 'perplexity' in out_keys:\n", - " out['perplexity'] = action_dist.perplexity()\n", - " \n", - " return out\n", - " \n", - "\n", - "class Agent(BaseAgent):\n", - " r\"\"\"REINFORCE (no baseline). 
\"\"\"\n", - " def make_modules(self, config):\n", - " self.policy = Policy(config, self.env_spec, self.device)\n", - " \n", - " def prepare(self, config, **kwargs):\n", - " self.total_T = 0\n", - " self.optimizer = optim.Adam(self.policy.parameters(), lr=config['algo.lr'])\n", - " if config['algo.use_lr_scheduler']:\n", - " if 'train.iter' in config:\n", - " self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.iter'], 'iteration-based')\n", - " elif 'train.timestep' in config:\n", - " self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.timestep']+1, 'timestep-based')\n", - " else:\n", - " self.lr_scheduler = None\n", - " \n", - "\n", - " def reset(self, config, **kwargs):\n", - " pass\n", - "\n", - " def choose_action(self, obs, info={}):\n", - " obs = torch.from_numpy(np.asarray(obs)).float().to(self.device)\n", - " \n", - " out = self.policy(obs, out_keys=['action', 'action_logprob', 'entropy'], info=info)\n", - " \n", - " # sanity check for NaN\n", - " if torch.any(torch.isnan(out['action'])):\n", - " while True:\n", - " print('NaN !')\n", - " if self.env_spec.control_type == 'Continuous':\n", - " out['action'] = constraint_action(self.env_spec, out['action'])\n", - " \n", - " return out\n", - "\n", - " def learn(self, D, info={}):\n", - " batch_policy_loss = []\n", - " batch_entropy_loss = []\n", - " batch_total_loss = []\n", - " \n", - " for trajectory in D:\n", - " logprobs = trajectory.all_info('action_logprob')\n", - " entropies = trajectory.all_info('entropy')\n", - " Qs = trajectory.all_discounted_returns(self.config['algo.gamma'])\n", - " \n", - " # Standardize: encourage/discourage half of performed actions\n", - " if self.config['agent.standardize_Q']:\n", - " Qs = Standardize()(Qs, -1).tolist()\n", - " \n", - " policy_loss = []\n", - " entropy_loss = []\n", - " for logprob, entropy, Q in zip(logprobs, entropies, Qs):\n", - " policy_loss.append(-logprob*Q)\n", - " entropy_loss.append(-entropy)\n", - " \n", - " policy_loss = torch.stack(policy_loss).mean()\n", - " entropy_loss = torch.stack(entropy_loss).mean()\n", - " \n", - " entropy_coef = self.config['agent.entropy_coef']\n", - " total_loss = policy_loss + entropy_coef*entropy_loss\n", - " \n", - " batch_policy_loss.append(policy_loss)\n", - " batch_entropy_loss.append(entropy_loss)\n", - " batch_total_loss.append(total_loss)\n", - " \n", - " policy_loss = torch.stack(batch_policy_loss).mean()\n", - " entropy_loss = torch.stack(batch_entropy_loss).mean()\n", - " loss = torch.stack(batch_total_loss).mean()\n", - " \n", - " self.optimizer.zero_grad()\n", - " loss.backward()\n", - " \n", - " if self.config['agent.max_grad_norm'] is not None:\n", - " clip_grad_norm_(self.parameters(), self.config['agent.max_grad_norm'])\n", - " \n", - " if self.lr_scheduler is not None:\n", - " if self.lr_scheduler.mode == 'iteration-based':\n", - " self.lr_scheduler.step()\n", - " elif self.lr_scheduler.mode == 'timestep-based':\n", - " self.lr_scheduler.step(self.total_T)\n", - "\n", - " self.optimizer.step()\n", - " \n", - " self.total_T += sum([trajectory.T for trajectory in D])\n", - " \n", - " out = {}\n", - " out['loss'] = loss.item()\n", - " out['policy_loss'] = policy_loss.item()\n", - " out['entropy_loss'] = entropy_loss.item()\n", - " if self.lr_scheduler is not None:\n", - " out['current_lr'] = self.lr_scheduler.get_lr()\n", - "\n", - " return out\n", - " \n", - " @property\n", - " def recurrent(self):\n", - " pass\n" - ] + "source": [] }, { "cell_type": "code", @@ -515,7 +348,7 @@ "name": 
"python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.0" + "version": "3.6.6" } }, "nbformat": 4, diff --git a/examples/policy_gradient/vpg/main.py b/examples/policy_gradient/vpg/main.py index 3dd33c96..1171bfed 100644 --- a/examples/policy_gradient/vpg/main.py +++ b/examples/policy_gradient/vpg/main.py @@ -6,5 +6,4 @@ run_experiment(worker_class=ExperimentWorker, master_class=ExperimentMaster, - max_num_worker=None, - daemonic_worker=None) + num_worker=100) diff --git a/examples/policy_gradient/vpg/model.py b/examples/policy_gradient/vpg/model.py index e98e95df..51fb7bc0 100644 --- a/examples/policy_gradient/vpg/model.py +++ b/examples/policy_gradient/vpg/model.py @@ -1,34 +1,128 @@ +import numpy as np +import torch +import torch.optim as optim +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import clip_grad_norm_ -class VPGAgent(BaseAgent): - r"""Vanilla Policy Gradient (VPG) with value network (baseline), no bootstrapping to estimate value function. """ - def __init__(self, config, device, policy, optimizer, **kwargs): - self.policy = policy - self.optimizer = optimizer +from lagom.networks import BaseNetwork +from lagom.networks import make_fc +from lagom.networks import ortho_init +from lagom.networks import linear_lr_scheduler + +from lagom.policies import BasePolicy +from lagom.policies import CategoricalHead +from lagom.policies import DiagGaussianHead +from lagom.policies import constraint_action + +from lagom.value_functions import StateValueHead + +from lagom.transform import Standardize + +from lagom.agents import BaseAgent + + +class MLP(BaseNetwork): + def make_params(self, config): + self.feature_layers = make_fc(self.env_spec.observation_space.flat_dim, config['network.hidden_sizes']) - super().__init__(config, device, **kwargs) + def init_params(self, config): + for layer in self.feature_layers: + ortho_init(layer, nonlinearity='tanh', constant_bias=0.0) - # accumulated trained timesteps - self.total_T = 0 + def reset(self, config, **kwargs): + pass - def choose_action(self, obs, info={}): - if not torch.is_tensor(obs): - obs = np.asarray(obs) - assert obs.ndim >= 2, f'expected at least 2-dim for batched data, got {obs.ndim}' - obs = torch.from_numpy(obs).float().to(self.device) - - if self.policy.recurrent and self.info['reset_rnn_states']: - self.policy.reset_rnn_states(batch_size=obs.size(0)) - self.info['reset_rnn_states'] = False # Done, turn off + def forward(self, x): + for layer in self.feature_layers: + x = torch.tanh(layer(x)) - out_policy = self.policy(obs, - out_keys=['action', 'action_logprob', 'state_value', - 'entropy', 'perplexity'], - info=info) + return x + + +class Policy(BasePolicy): + def make_networks(self, config): + self.feature_network = MLP(config, self.device, env_spec=self.env_spec) + feature_dim = config['network.hidden_sizes'][-1] - return out_policy + if self.env_spec.control_type == 'Discrete': + self.action_head = CategoricalHead(config, self.device, feature_dim, self.env_spec) + elif self.env_spec.control_type == 'Continuous': + self.action_head = DiagGaussianHead(config, + self.device, + feature_dim, + self.env_spec, + min_std=config['agent.min_std'], + std_style=config['agent.std_style'], + constant_std=config['agent.constant_std'], + std_state_dependent=config['agent.std_state_dependent'], + init_std=config['agent.init_std']) + self.V_head = StateValueHead(config, self.device, feature_dim) + + @property + def recurrent(self): + return False + + def reset(self, config, 
**kwargs): + pass + + def __call__(self, x, out_keys=['action', 'V'], info={}, **kwargs): + out = {} + + features = self.feature_network(x) + action_dist = self.action_head(features) + + action = action_dist.sample().detach()################################ + out['action'] = action + + V = self.V_head(features) + out['V'] = V + + if 'action_logprob' in out_keys: + out['action_logprob'] = action_dist.log_prob(action) + if 'entropy' in out_keys: + out['entropy'] = action_dist.entropy() + if 'perplexity' in out_keys: + out['perplexity'] = action_dist.perplexity() + + return out + + +class Agent(BaseAgent): + r"""Vanilla Policy Gradient (VPG) with value network (baseline), no bootstrapping to estimate value function. """ + def make_modules(self, config): + self.policy = Policy(config, self.env_spec, self.device) + def prepare(self, config, **kwargs): + self.total_T = 0 + self.optimizer = optim.Adam(self.policy.parameters(), lr=config['algo.lr']) + if config['algo.use_lr_scheduler']: + if 'train.iter' in config: + self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.iter'], 'iteration-based') + elif 'train.timestep' in config: + self.lr_scheduler = linear_lr_scheduler(self.optimizer, config['train.timestep']+1, 'timestep-based') + else: + self.lr_scheduler = None + + def reset(self, config, **kwargs): + pass + + def choose_action(self, obs, info={}): + obs = torch.from_numpy(np.asarray(obs)).float().to(self.device) + + out = self.policy(obs, out_keys=['action', 'action_logprob', 'V', 'entropy'], info=info) + + # sanity check for NaN + if torch.any(torch.isnan(out['action'])): + while True: + print('NaN !') + if self.env_spec.control_type == 'Continuous': + out['action'] = constraint_action(self.env_spec, out['action']) + + return out + def learn(self, D, info={}): batch_policy_loss = [] batch_entropy_loss = [] @@ -38,23 +132,24 @@ def learn(self, D, info={}): for trajectory in D: logprobs = trajectory.all_info('action_logprob') entropies = trajectory.all_info('entropy') - Qs = trajectory.all_discounted_returns + Qs = trajectory.all_discounted_returns(self.config['algo.gamma']) # Standardize: encourage/discourage half of performed actions if self.config['agent.standardize_Q']: - Qs = Standardize()(Qs).tolist() - - # State values - Vs = trajectory.all_info('V_s') - final_V = trajectory.transitions[-1].V_s_next - final_done = trajectory.transitions[-1].done - - # Advantage estimates + Qs = Standardize()(Qs, -1).tolist() + + Vs = trajectory.all_info('V') + if trajectory.complete: + terminal_state = trajectory.transitions[-1].s_next + terminal_state = torch.tensor([terminal_state]).float().to(self.device) + V_terminal = self.policy(terminal_state)['V'].squeeze(0) + else: + V_terminal = None + As = [Q - V.item() for Q, V in zip(Qs, Vs)] if self.config['agent.standardize_adv']: - As = Standardize()(As).tolist() + As = Standardize()(As, -1).tolist() - # Estimate policy gradient for all time steps and record all losses policy_loss = [] entropy_loss = [] value_loss = [] @@ -62,51 +157,41 @@ def learn(self, D, info={}): policy_loss.append(-logprob*A) entropy_loss.append(-entropy) value_loss.append(F.mse_loss(V, torch.tensor(Q).view_as(V).to(V.device))) - if final_done: # learn terminal state value as zero - value_loss.append(F.mse_loss(final_V, torch.tensor(0.0).view_as(V).to(V.device))) + if V_terminal is not None: + value_loss.append(F.mse_loss(V_terminal, torch.tensor(0.0).view_as(V).to(V.device))) - # Average losses over all time steps policy_loss = torch.stack(policy_loss).mean() 
entropy_loss = torch.stack(entropy_loss).mean() value_loss = torch.stack(value_loss).mean() - # Calculate total loss entropy_coef = self.config['agent.entropy_coef'] value_coef = self.config['agent.value_coef'] total_loss = policy_loss + value_coef*value_loss + entropy_coef*entropy_loss - # Record all losses batch_policy_loss.append(policy_loss) batch_entropy_loss.append(entropy_loss) batch_value_loss.append(value_loss) batch_total_loss.append(total_loss) - - # Average loss over list of Trajectory + policy_loss = torch.stack(batch_policy_loss).mean() entropy_loss = torch.stack(batch_entropy_loss).mean() value_loss = torch.stack(batch_value_loss).mean() loss = torch.stack(batch_total_loss).mean() - - # Train with estimated policy gradient + self.optimizer.zero_grad() loss.backward() if self.config['agent.max_grad_norm'] is not None: - clip_grad_norm_(parameters=self.policy.network.parameters(), - max_norm=self.config['agent.max_grad_norm'], - norm_type=2) - - if hasattr(self, 'lr_scheduler'): - if 'train.iter' in self.config: # iteration-based + clip_grad_norm_(self.parameters(), self.config['agent.max_grad_norm']) + + if self.lr_scheduler is not None: + if self.lr_scheduler.mode == 'iteration-based': self.lr_scheduler.step() - elif 'train.timestep' in self.config: # timestep-based + elif self.lr_scheduler.mode == 'timestep-based': self.lr_scheduler.step(self.total_T) - else: - raise KeyError('expected `train.iter` or `train.timestep` in config, but got none of them') - + self.optimizer.step() - # Accumulate trained timesteps self.total_T += sum([trajectory.T for trajectory in D]) out = {} @@ -114,13 +199,11 @@ def learn(self, D, info={}): out['policy_loss'] = policy_loss.item() out['entropy_loss'] = entropy_loss.item() out['value_loss'] = value_loss.item() - if hasattr(self, 'lr_scheduler'): + if self.lr_scheduler is not None: out['current_lr'] = self.lr_scheduler.get_lr() return out - def save(self, f): - self.policy.network.save(f) - - def load(self, f): - self.policy.network.load(f) + @property + def recurrent(self): + pass diff --git a/lagom/history/metrics/__init__.py b/lagom/history/metrics/__init__.py new file mode 100644 index 00000000..926bb487 --- /dev/null +++ b/lagom/history/metrics/__init__.py @@ -0,0 +1,8 @@ +from .terminal_states import terminal_state_from_trajectory +from .terminal_states import terminal_state_from_segment + +from .final_states import final_state_from_trajectory +from .final_states import final_state_from_segment + +from .bootstrapped_returns import bootstrapped_returns_from_trajectory +from .bootstrapped_returns import bootstrapped_returns_from_segment diff --git a/lagom/history/metrics/bootstrapped_returns.py b/lagom/history/metrics/bootstrapped_returns.py new file mode 100644 index 00000000..9a3d37b7 --- /dev/null +++ b/lagom/history/metrics/bootstrapped_returns.py @@ -0,0 +1,79 @@ +import numpy as np +import torch + +from lagom.history import Trajectory +from lagom.history import Segment + +from lagom.transform import ExpFactorCumSum + + +def bootstrapped_returns_from_trajectory(trajectory, V_last, gamma=1.0): + r"""Return a list of (discounted) accumulated returns with bootstrapping for all + time steps, from a trajectory. + + Formally, suppose we have all rewards :math:`(r_1, \dots, r_T)`, it computes + + .. math:: + Q_t = r_t + \gamma r_{t+1} + \dots + \gamma^{T - t} r_T + \gamma^{T - t + 1} V(s_{T+1}) + + .. note:: + + The state value for terminal state is set as zero ! 
+    
+    Args:
+        trajectory (Trajectory): a trajectory
+        V_last (object): the value of the final state in the trajectory
+        gamma (float): discount factor
+        
+    Returns
+    -------
+    out : list
+        a list of (discounted) bootstrapped returns
+    """
+    assert isinstance(trajectory, Trajectory)
+    
+    if torch.is_tensor(V_last):
+        V_last = V_last.item()
+    if isinstance(V_last, np.ndarray):
+        V_last = V_last.item()
+    
+    if trajectory.complete:
+        V_last = 0.0
+    
+    out = ExpFactorCumSum(gamma)(trajectory.all_r + [V_last])
+    out = out[:-1] # last one is just state value itself
+    
+    return out
+
+
+def bootstrapped_returns_from_segment(segment, all_V_last, gamma=1.0):
+    r"""Return a list of (discounted) accumulated returns with bootstrapping for all
+    time steps, from a segment.
+    
+    Formally, suppose we have all rewards :math:`(r_1, \dots, r_T)`, it computes
+    
+    .. math::
+        Q_t = r_t + \gamma r_{t+1} + \dots + \gamma^{T - t} r_T + \gamma^{T - t + 1} V(s_{T+1})
+    
+    .. note::
+    
+        The state value of the terminal state is set to zero.
+    
+    Args:
+        segment (Segment): a segment
+        all_V_last (list): the values of the final state of each trajectory in the segment
+        gamma (float): discount factor
+        
+    Returns
+    -------
+    out : list
+        a list of (discounted) bootstrapped returns
+    """
+    assert isinstance(segment, Segment)
+    
+    assert len(segment.trajectories) == len(all_V_last)
+    out = []
+    for trajectory, V_last in zip(segment.trajectories, all_V_last):
+        out += bootstrapped_returns_from_trajectory(trajectory, V_last, gamma)
+    
+    return out
diff --git a/lagom/history/metrics/final_states.py b/lagom/history/metrics/final_states.py
new file mode 100644
index 00000000..e73866fc
--- /dev/null
+++ b/lagom/history/metrics/final_states.py
@@ -0,0 +1,37 @@
+from lagom.history import Trajectory
+from lagom.history import Segment
+
+
+def final_state_from_trajectory(trajectory):
+    r"""Return the final state of a trajectory.
+    
+    Args:
+        trajectory (Trajectory): a trajectory
+        
+    Returns
+    -------
+    out : object
+    """
+    assert isinstance(trajectory, Trajectory)
+    return trajectory.transitions[-1].s_next
+
+
+def final_state_from_segment(segment):
+    r"""Return a list of final states of a segment.
+    
+    It collects the final state from each trajectory stored in the segment.
+    
+    Args:
+        segment (Segment): a segment
+        
+    Returns
+    -------
+    out : list
+    """
+    assert isinstance(segment, Segment)
+    
+    final_states = []
+    for trajectory in segment.trajectories:
+        final_states.append(final_state_from_trajectory(trajectory))
+    
+    return final_states
diff --git a/lagom/history/metrics/terminal_states.py b/lagom/history/metrics/terminal_states.py
new file mode 100644
index 00000000..81affa63
--- /dev/null
+++ b/lagom/history/metrics/terminal_states.py
@@ -0,0 +1,47 @@
+from lagom.history import Trajectory
+from lagom.history import Segment
+
+
+def terminal_state_from_trajectory(trajectory):
+    r"""Return the terminal state of a trajectory if available.
+    
+    If the trajectory does not have a terminal state, then ``None`` is returned.
+    
+    Args:
+        trajectory (Trajectory): a trajectory
+        
+    Returns
+    -------
+    out : object
+    """
+    assert isinstance(trajectory, Trajectory)
+    
+    if trajectory.complete:
+        return trajectory.transitions[-1].s_next
+    else:
+        return None
+
+
+def terminal_state_from_segment(segment):
+    r"""Return a list of terminal states of a segment if available.
+    
+    It collects the terminal state from each trajectory stored in the segment.
+    
+    If the segment does not have any terminal state, then an empty list is returned.
+ + Args: + segment (Segment): a segment + + Returns + ------- + out : object + """ + assert isinstance(segment, Segment) + + terminal_states = [] + + for trajectory in segment.trajectories: + if trajectory.complete: + terminal_states.append(terminal_state_from_trajectory(trajectory)) + + return terminal_states diff --git a/lagom/runner/segment_runner.py b/lagom/runner/segment_runner.py index 0d7bfd26..076250fa 100644 --- a/lagom/runner/segment_runner.py +++ b/lagom/runner/segment_runner.py @@ -120,6 +120,7 @@ def __call__(self, T, reset=False): # Record additional information [transition.add_info(key, val[i]) for key, val in out_agent.items()] + [transition.add_info(key, val) for key, val in info[i].items()] segment.add_transition(transition) diff --git a/lagom/runner/trajectory_runner.py b/lagom/runner/trajectory_runner.py index a6675457..1126f599 100644 --- a/lagom/runner/trajectory_runner.py +++ b/lagom/runner/trajectory_runner.py @@ -88,6 +88,7 @@ def __call__(self, T): # Record additional information [transition.add_info(key, val[i]) for key, val in out_agent.items()] + [transition.add_info(key, val) for key, val in info[i].items()] trajectory.add_transition(transition) diff --git a/test/test_history.py b/test/test_history.py index a07a5919..6202050d 100644 --- a/test/test_history.py +++ b/test/test_history.py @@ -8,6 +8,13 @@ from lagom.history import Trajectory from lagom.history import Segment +from lagom.history.metrics import terminal_state_from_trajectory +from lagom.history.metrics import terminal_state_from_segment +from lagom.history.metrics import final_state_from_trajectory +from lagom.history.metrics import final_state_from_segment +from lagom.history.metrics import bootstrapped_returns_from_trajectory +from lagom.history.metrics import bootstrapped_returns_from_segment + def test_transition(): transition = Transition(s=1.2, a=2.0, r=-1.0, s_next=1.5, done=True) @@ -312,3 +319,140 @@ def test_segment(): del transition3 del transition4 del all_info + + +def test_terminal_state_from_trajectory(): + t = Trajectory() + t.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + t.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + t.add_transition(Transition(3.0, 30, 0.3, 4.0, True)) + + assert terminal_state_from_trajectory(t) == 4.0 + + t = Trajectory() + t.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + t.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + t.add_transition(Transition(3.0, 30, 0.3, 4.0, False)) + + assert terminal_state_from_trajectory(t) is None + + with pytest.raises(AssertionError): + terminal_state_from_segment(t) + +def test_terminal_state_from_segment(): + s = Segment() + s.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + s.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + s.add_transition(Transition(3.0, 30, 0.3, 4.0, True)) + s.add_transition(Transition(5.0, 50, 0.5, 6.0, False)) + s.add_transition(Transition(6.0, 60, 0.6, 7.0, True)) + assert terminal_state_from_segment(s) == [4.0, 7.0] + + s = Segment() + s.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + s.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + s.add_transition(Transition(3.0, 30, 0.3, 4.0, True)) + s.add_transition(Transition(5.0, 50, 0.5, 6.0, False)) + s.add_transition(Transition(6.0, 60, 0.6, 7.0, False)) + assert terminal_state_from_segment(s) == [4.0] + + with pytest.raises(AssertionError): + terminal_state_from_trajectory(s) + + +def test_final_state_from_trajectory(): + t = Trajectory() + t.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + 
t.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + t.add_transition(Transition(3.0, 30, 0.3, 4.0, True)) + + assert final_state_from_trajectory(t) == 4.0 + + t = Trajectory() + t.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + t.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + t.add_transition(Transition(3.0, 30, 0.3, 4.0, False)) + + assert final_state_from_trajectory(t) == 4.0 + + with pytest.raises(AssertionError): + final_state_from_segment(t) + + +def test_final_state_from_segment(): + s = Segment() + s.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + s.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + s.add_transition(Transition(3.0, 30, 0.3, 4.0, True)) + s.add_transition(Transition(5.0, 50, 0.5, 6.0, False)) + s.add_transition(Transition(6.0, 60, 0.6, 7.0, True)) + assert final_state_from_segment(s) == [4.0, 7.0] + + s = Segment() + s.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + s.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + s.add_transition(Transition(3.0, 30, 0.3, 4.0, True)) + s.add_transition(Transition(5.0, 50, 0.5, 6.0, False)) + s.add_transition(Transition(6.0, 60, 0.6, 7.0, False)) + assert final_state_from_segment(s) == [4.0, 7.0] + + with pytest.raises(AssertionError): + final_state_from_trajectory(s) + + +def test_bootstrapped_returns_from_trajectory(): + t = Trajectory() + t.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + t.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + t.add_transition(Transition(3.0, 30, 0.3, 4.0, True)) + V_last = 100 + + out = bootstrapped_returns_from_trajectory(t, V_last, 1.0) + assert np.allclose(out, [0.6, 0.5, 0.3]) + out = bootstrapped_returns_from_trajectory(t, V_last, 0.1) + assert np.allclose(out, [0.123, 0.23, 0.3]) + + t = Trajectory() + t.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + t.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + t.add_transition(Transition(3.0, 30, 0.3, 4.0, False)) + V_last = 100 + + out = bootstrapped_returns_from_trajectory(t, V_last, 1.0) + assert np.allclose(out, [100.6, 100.5, 100.3]) + out = bootstrapped_returns_from_trajectory(t, V_last, 0.1) + assert np.allclose(out, [0.223, 1.23, 10.3]) + + with pytest.raises(AssertionError): + bootstrapped_returns_from_segment(t, V_last, 1.0) + + +def test_bootstrapped_returns_from_segment(): + s = Segment() + s.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + s.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + s.add_transition(Transition(3.0, 30, 0.3, 4.0, True)) + s.add_transition(Transition(5.0, 50, 0.5, 6.0, False)) + s.add_transition(Transition(6.0, 60, 0.6, 7.0, True)) + all_V_last = [50, 100] + + out = bootstrapped_returns_from_segment(s, all_V_last, 1.0) + assert np.allclose(out, [0.6, 0.5, 0.3, 1.1, 0.6]) + out = bootstrapped_returns_from_segment(s, all_V_last, 0.1) + assert np.allclose(out, [0.123, 0.23, 0.3, 0.56, 0.6]) + + s = Segment() + s.add_transition(Transition(1.0, 10, 0.1, 2.0, False)) + s.add_transition(Transition(2.0, 20, 0.2, 3.0, False)) + s.add_transition(Transition(3.0, 30, 0.3, 4.0, True)) + s.add_transition(Transition(5.0, 50, 0.5, 6.0, False)) + s.add_transition(Transition(6.0, 60, 0.6, 7.0, False)) + all_V_last = [50, 100] + + out = bootstrapped_returns_from_segment(s, all_V_last, 1.0) + assert np.allclose(out, [0.6, 0.5, 0.3, 101.1, 100.6]) + out = bootstrapped_returns_from_segment(s, all_V_last, 0.1) + assert np.allclose(out, [0.123, 0.23, 0.3, 1.56, 10.6]) + + with pytest.raises(AssertionError): + bootstrapped_returns_from_trajectory(s, all_V_last, 
1.0)
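
A minimal usage sketch of the new lagom.history.metrics helpers introduced in this patch, mirroring the calls exercised in the added test/test_history.py. The Transition(s, a, r, s_next, done) constructor, the keyword names V_last and gamma, and the expected values are taken from those tests; the import path for Transition is assumed to match Trajectory and Segment.

from lagom.history import Transition  # assumed to be importable like Trajectory/Segment
from lagom.history import Trajectory
from lagom.history.metrics import final_state_from_trajectory
from lagom.history.metrics import terminal_state_from_trajectory
from lagom.history.metrics import bootstrapped_returns_from_trajectory

# Build a short trajectory; the last transition terminates the episode.
t = Trajectory()
t.add_transition(Transition(s=1.0, a=10, r=0.1, s_next=2.0, done=False))
t.add_transition(Transition(s=2.0, a=20, r=0.2, s_next=3.0, done=False))
t.add_transition(Transition(s=3.0, a=30, r=0.3, s_next=4.0, done=True))

final_state_from_trajectory(t)     # 4.0: the last s_next, terminal or not
terminal_state_from_trajectory(t)  # 4.0 here; None for an incomplete trajectory

# Bootstrapped returns: V_last is the value of the final state, but it is
# replaced by zero because this trajectory reached a terminal state.
bootstrapped_returns_from_trajectory(t, V_last=100, gamma=1.0)  # [0.6, 0.5, 0.3]
bootstrapped_returns_from_trajectory(t, V_last=100, gamma=0.1)  # [0.123, 0.23, 0.3]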
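The engine changes in this patch also switch from logger.log(key, value) to calling the Logger instance directly. A small sketch of that call pattern, inferred from the engine.py hunks; the "from lagom import Logger" import path is an assumption, since the import line is not part of the hunks shown here.

from lagom import Logger  # assumed import path

logger = Logger()
logger('train_iteration', 1)    # record a key-value pair by calling the instance
logger('average_return', 12.3)
logger.dump(keys=None, index=None, indent=0)  # print everything recorded so far
logs = logger.logs  # the raw logged data, returned by log_train()/log_eval() in the engine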