diff --git a/examples/policy_gradient/a2c/algo.py b/examples/policy_gradient/a2c/algo.py index ac78f7fd..0212eb86 100644 --- a/examples/policy_gradient/a2c/algo.py +++ b/examples/policy_gradient/a2c/algo.py @@ -5,135 +5,147 @@ import numpy as np import torch -import torch.optim as optim import torch.nn as nn import torch.nn.functional as F +import torch.optim as optim from lagom import set_global_seeds from lagom import BaseAlgorithm -from lagom.envs import EnvSpec +from lagom import pickle_dump -from lagom.envs import make_envs from lagom.envs import make_gym_env +from lagom.envs import make_vec_env +from lagom.envs import EnvSpec from lagom.envs.vec_env import SerialVecEnv -from lagom.envs.vec_env import StandardizeVecEnv +from lagom.envs.vec_env import VecStandardize +from lagom.core.policies import CategoricalPolicy +from lagom.core.policies import GaussianPolicy + +from lagom.runner import TrajectoryRunner from lagom.runner import SegmentRunner from lagom.agents import A2CAgent from engine import Engine -from policy import CategoricalMLP -from policy import CategoricalPolicy -from policy import GaussianMLP -from policy import GaussianPolicy +from policy import Network class Algorithm(BaseAlgorithm): - def __call__(self, config): - # Set random seeds: PyTorch, numpy.random, random - set_global_seeds(seed=config['seed']) + def __call__(self, config, seed, device_str): + # Set random seeds + set_global_seeds(seed) + # Create device + device = torch.device(device_str) + # Use log dir for current job (run_experiment) + logdir = Path(config['log.dir']) / str(config['ID']) / str(seed) - # Create an VecEnv environment - list_make_env = make_envs(make_env=make_gym_env, - env_id=config['env:id'], - num_env=config['train:N'], - init_seed=config['seed']) - env = SerialVecEnv(list_make_env) - # Wrapper to standardize observation and reward from running average - if config['env:normalize']: - env = StandardizeVecEnv(venv=env, - use_obs=True, - use_reward=True, - clip_obs=10., - clip_reward=10., - gamma=0.99, - eps=1e-8) - # Create environment specification + # Make environment (VecEnv) for training and evaluating + env = make_vec_env(vec_env_class=SerialVecEnv, + make_env=make_gym_env, + env_id=config['env.id'], + num_env=1, + init_seed=seed) + eval_env = make_vec_env(vec_env_class=SerialVecEnv, + make_env=make_gym_env, + env_id=config['env.id'], + num_env=1, + init_seed=seed) + if config['env.standardize']: # wrap with VecStandardize for running averages of observation and rewards + env = VecStandardize(venv=env, + use_obs=True, + use_reward=True, + clip_obs=10., + clip_reward=10., + gamma=0.99, + eps=1e-8) + eval_env = VecStandardize(venv=eval_env, # remember to synchronize running averages during evaluation !!! 
+ use_obs=True, + use_reward=False, # do not process rewards, no training + clip_obs=env.clip_obs, + clip_reward=env.clip_reward, + gamma=env.gamma, + eps=env.eps, + constant_obs_mean=env.obs_runningavg.mu, # use current running average as constant + constant_obs_std=env.obs_runningavg.sigma) env_spec = EnvSpec(env) - # Create device object, note that in BaseExperimentWorker already assigns a specific GPU for this task - device = torch.device(f'cuda:{torch.cuda.current_device()}' if config['cuda'] else 'cpu') - # Create policy + network = Network(config=config, env_spec=env_spec) if env_spec.control_type == 'Discrete': - network = CategoricalMLP(config=config, env_spec=env_spec).to(device) - policy = CategoricalPolicy(network=network, - env_spec=env_spec, - config=config) + policy = CategoricalPolicy(config=config, network=network, env_spec=env_spec, learn_V=True) elif env_spec.control_type == 'Continuous': - network = GaussianMLP(config=config, env_spec=env_spec).to(device) - policy = GaussianPolicy(network=network, + policy = GaussianPolicy(config=config, + network=network, env_spec=env_spec, - config=config, - min_std=config['agent:min_std'], - std_style=config['agent:std_style'], - constant_std=config['agent:constant_std']) - - # Create optimizer - optimizer = optim.Adam(policy.network.parameters(), lr=config['algo:lr']) - # Create learning rate scheduler - if config['algo:use_lr_scheduler']: - # Define max number of lr decay - if 'train:iter' in config: # iteration-based training - max_epoch = config['train:iter'] - elif 'train:timestep' in config: # timestep-based training - max_epoch = config['train:timestep'] + 1 # plus 1 avoid having 0.0 lr in final iteration + learn_V=True, + min_std=config['agent.min_std'], + std_style=config['agent.std_style'], + constant_std=config['agent.constant_std'], + std_state_dependent=config['agent.std_state_dependent'], + init_std=config['agent.init_std']) + network = network.to(device) + + # Create optimizer and learning rate scheduler + optimizer = optim.Adam(policy.network.parameters(), lr=config['algo.lr']) + if config['algo.use_lr_scheduler']: + if 'train.iter' in config: # iteration-based training + max_epoch = config['train.iter'] + elif 'train.timestep' in config: # timestep-based training + max_epoch = config['train.timestep'] + 1 # +1 to avoid 0.0 lr in final iteration lambda_f = lambda epoch: 1 - epoch/max_epoch # decay learning rate for each training epoch lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_f) # Create agent kwargs = {'device': device} - if config['algo:use_lr_scheduler']: + if config['algo.use_lr_scheduler']: kwargs['lr_scheduler'] = lr_scheduler - agent = A2CAgent(policy=policy, + agent = A2CAgent(config=config, + policy=policy, optimizer=optimizer, - config=config, **kwargs) # Create runner runner = SegmentRunner(agent=agent, env=env, - gamma=config['algo:gamma']) + gamma=config['algo.gamma']) + eval_runner = TrajectoryRunner(agent=agent, + env=eval_env, + gamma=1.0) # Create engine engine = Engine(agent=agent, runner=runner, config=config, - logger=None) + eval_runner=eval_runner) # Training and evaluation train_logs = [] eval_logs = [] - for i in count(): # successively increment iteration - # Terminate until condition is met - if 'train:iter' in config and i >= config['train:iter']: # enough iteration, terminate + for i in count(): # incremental iteration + if 'train.iter' in config and i >= config['train.iter']: # enough iterations break - elif 'train:timestep' in config and 
agent.accumulated_trained_timesteps >= config['train:timestep']: + elif 'train.timestep' in config and agent.total_T >= config['train.timestep']: # enough timesteps break - - # Do training - train_output = engine.train(i) - # Logging and evaluation - if i == 0 or (i+1) % config['log:interval'] == 0: - # Log training and record the loggings - train_logger = engine.log_train(train_output) - train_logs.append(train_logger.logs) - # Log evaluation and record the loggings - with torch.no_grad(): # no need to have gradient, save memory - eval_output = engine.eval(i) - eval_logger = engine.log_eval(eval_output) - eval_logs.append(eval_logger.logs) - - # Save the logging periodically - # This is good to avoid saving very large file at once, because the program might get stuck - # The file name is augmented with current iteration - np.save(Path(config['log:dir']) / str(config['ID']) / f'train:{i}', train_logs) - np.save(Path(config['log:dir']) / str(config['ID']) / f'eval:{i}', eval_logs) - # Clear the logging list - train_logs.clear() - eval_logs.clear() + # train and evaluation + train_output = engine.train(n=i) + + # logging + if i == 0 or (i+1) % config['log.record_interval'] == 0 or (i+1) % config['log.print_interval'] == 0: + train_log = engine.log_train(train_output) + + with torch.no_grad(): # disable grad, save memory + eval_output = engine.eval(n=i) + eval_log = engine.log_eval(eval_output) + + if i == 0 or (i+1) % config['log.record_interval'] == 0: # record loggings + train_logs.append(train_log) + eval_logs.append(eval_log) + + # Save all loggings + pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl') + pickle_dump(obj=eval_logs, f=logdir/'eval_logs', ext='.pkl') return None diff --git a/examples/policy_gradient/a2c/engine.py b/examples/policy_gradient/a2c/engine.py index 63a33ea8..7632f3c4 100644 --- a/examples/policy_gradient/a2c/engine.py +++ b/examples/policy_gradient/a2c/engine.py @@ -3,21 +3,24 @@ import torch from lagom import Logger +from lagom import color_str + from lagom.engine import BaseEngine + from lagom.envs import make_gym_env from lagom.envs import make_envs from lagom.envs.vec_env import SerialVecEnv -from lagom.envs.vec_env import StandardizeVecEnv +from lagom.envs.vec_env import VecStandardize + from lagom.runner import TrajectoryRunner class Engine(BaseEngine): def train(self, n): - # Set network as training mode - self.agent.policy.network.train() + self.agent.policy.network.train() # set to train mode # Collect a list of segments - D = self.runner(T=self.config['train:T']) + D = self.runner(T=self.config['train.T']) # Train agent with collected data out_agent = self.agent.learn(D) @@ -29,7 +32,7 @@ def train(self, n): return train_output - def log_train(self, train_output): + def log_train(self, train_output, **kwargs): # Create training logger logger = Logger(name='train_logger') @@ -38,68 +41,48 @@ def log_train(self, train_output): out_agent = train_output['out_agent'] n = train_output['n'] - # Loggings - # Use item() for tensor to save memory - logger.log(key='train_iteration', val=n+1) # iteration starts from 1 - if self.config['algo:use_lr_scheduler']: - logger.log(key='current_lr', val=out_agent['current_lr']) + # Loggings: use item() to save memory + logger.log('train_iteration', n+1) # iteration starts from 1 + if self.config['algo.use_lr_scheduler']: + logger.log('current_lr', out_agent['current_lr']) - logger.log(key='loss', val=out_agent['loss'].item()) - policy_loss = torch.stack(out_agent['batch_policy_loss']).mean().item() - 
logger.log(key='policy_loss', val=policy_loss) - entropy_loss = torch.stack(out_agent['batch_entropy_loss']).mean().item() - logger.log(key='policy_entropy', val=-entropy_loss) # negation of entropy loss - value_loss = torch.stack(out_agent['batch_value_loss']).mean().item() - logger.log(key='value_loss', val=value_loss) - - # Get some data from segment list + logger.log('loss', out_agent['loss']) + logger.log('policy_loss', out_agent['policy_loss']) + logger.log('policy_entropy', -out_agent['entropy_loss']) # negate entropy loss is entropy + logger.log('value_loss', out_agent['value_loss']) + + # Log something about segments all_immediate_reward = [segment.all_r for segment in D] num_timesteps = sum([segment.T for segment in D]) - # Log more information - logger.log(key='num_segments', val=sum([len(segment.split_transitions) for segment in D])) - logger.log(key='num_timesteps', val=num_timesteps) - logger.log(key='accumulated_trained_timesteps', val=self.agent.accumulated_trained_timesteps) - logger.log(key='average_immediate_reward', val=np.mean(all_immediate_reward)) - logger.log(key='std_immediate_reward', val=np.std(all_immediate_reward)) - logger.log(key='min_immediate_reward', val=np.min(all_immediate_reward)) - logger.log(key='max_immediate_reward', val=np.max(all_immediate_reward)) + logger.log('num_segments', len(D)) + logger.log('num_subsegments', sum([len(segment.trajectories) for segment in D])) + logger.log('num_timesteps', num_timesteps) + logger.log('accumulated_trained_timesteps', self.agent.total_T) + logger.log('average_immediate_reward', np.mean(all_immediate_reward)) + logger.log('std_immediate_reward', np.std(all_immediate_reward)) + logger.log('min_immediate_reward', np.min(all_immediate_reward)) + logger.log('max_immediate_reward', np.max(all_immediate_reward)) + + # Dump loggings + if n == 0 or (n+1) % self.config['log.print_interval'] == 0: + print('-'*50) + logger.dump(keys=None, index=None, indent=0) + print('-'*50) + + return logger.logs - # Dump the loggings - print('-'*50) - logger.dump(keys=None, index=None, indent=0) - print('-'*50) + def eval(self, n): + self.agent.policy.network.eval() # set to evaluation mode - return logger + # synchronize running average if environment wrapped by VecStandardize + if self.config['env.standardize']: + self.eval_runner.env.constant_obs_mean = self.runner.env.obs_runningavg.mu + self.eval_runner.env.constant_obs_std = self.runner.env.obs_runningavg.sigma - def eval(self, n): - # Set network as evaluation mode - self.agent.policy.network.eval() - - # Create a new instance of VecEnv envrionment - list_make_env = make_envs(make_env=make_gym_env, - env_id=self.config['env:id'], - num_env=1, - init_seed=self.config['seed']) - env = SerialVecEnv(list_make_env) - # Wrapper to standardize observation from training scaling of mean and standard deviation - if self.config['env:normalize']: - env = StandardizeVecEnv(venv=env, - use_obs=True, - use_reward=False, # do not standardize reward, use original - clip_obs=self.runner.env.clip_obs, - eps=self.runner.env.eps, - constant_obs_mean=self.runner.env.obs_runningavg.mu, - constant_obs_std=self.runner.env.obs_runningavg.sigma) - - # Create a TrajectoryRunner - runner = TrajectoryRunner(agent=self.agent, - env=env, - gamma=self.config['algo:gamma']) - # Evaluate the agent for a set of trajectories - # Retrieve pre-defined maximum episode timesteps in the environment - T = env.T[0] # take first one because of VecEnv - D = runner(N=self.config['eval:N'], T=T) + # Collect a list of 
trajectories + T = self.eval_runner.env.T + D = self.eval_runner(N=self.config['eval.N'], T=T) # Return evaluation output eval_output = {} @@ -109,7 +92,7 @@ def eval(self, n): return eval_output - def log_eval(self, eval_output): + def log_eval(self, eval_output, **kwargs): # Create evaluation logger logger = Logger(name='eval_logger') @@ -118,26 +101,26 @@ def log_eval(self, eval_output): n = eval_output['n'] T = eval_output['T'] - # Compute some metrics + # Loggings: use item() to save memory + # Log something about trajectories batch_returns = [sum(trajectory.all_r) for trajectory in D] batch_T = [trajectory.T for trajectory in D] - # Loggings - # Use item() for tensor to save memory - logger.log(key='evaluation_iteration', val=n+1) - logger.log(key='num_trajectories', val=len(D)) - logger.log(key='max_allowed_horizon', val=T) - logger.log(key='average_horizon', val=np.mean(batch_T)) - logger.log(key='num_timesteps', val=np.sum(batch_T)) - logger.log(key='accumulated_trained_timesteps', val=self.agent.accumulated_trained_timesteps) - logger.log(key='average_return', val=np.mean(batch_returns)) - logger.log(key='std_return', val=np.std(batch_returns)) - logger.log(key='min_return', val=np.min(batch_returns)) - logger.log(key='max_return', val=np.max(batch_returns)) - - # Dump the loggings - print('-'*50) - logger.dump(keys=None, index=None, indent=0) - print('-'*50) - - return logger + logger.log('evaluation_iteration', n+1) + logger.log('num_trajectories', len(D)) + logger.log('max_allowed_horizon', T) + logger.log('average_horizon', np.mean(batch_T)) + logger.log('num_timesteps', np.sum(batch_T)) + logger.log('accumulated_trained_timesteps', self.agent.total_T) + logger.log('average_return', np.mean(batch_returns)) + logger.log('std_return', np.std(batch_returns)) + logger.log('min_return', np.min(batch_returns)) + logger.log('max_return', np.max(batch_returns)) + + # Dump loggings + if n == 0 or (n+1) % self.config['log.print_interval'] == 0: + print(color_str('+'*50, 'yellow', 'bold')) + logger.dump(keys=None, index=None, indent=0) + print(color_str('+'*50, 'yellow', 'bold')) + + return logger.logs diff --git a/examples/policy_gradient/a2c/experiment.py b/examples/policy_gradient/a2c/experiment.py index 07d3719c..b508ee3e 100644 --- a/examples/policy_gradient/a2c/experiment.py +++ b/examples/policy_gradient/a2c/experiment.py @@ -1,9 +1,9 @@ -from algo import Algorithm - -from lagom.experiment import Config +from lagom.experiment import Configurator from lagom.experiment import BaseExperimentWorker from lagom.experiment import BaseExperimentMaster +from algo import Algorithm + class ExperimentWorker(BaseExperimentWorker): def make_algo(self): @@ -13,88 +13,48 @@ def make_algo(self): class ExperimentMaster(BaseExperimentMaster): - def process_algo_result(self, config, result): - assert result is None - def make_configs(self): - config = Config() + configurator = Configurator('grid') + + configurator.fixed('cuda', True) # whether to use GPU - ########################## - # General configurations # - ########################## - # Whether to use GPU - config.add_item(name='cuda', val=True) - # Random seeds: generated by `np.random.randint(0, np.iinfo(np.int32).max, 5)` - config.add_grid(name='seed', val=[144682090, 591442434, 1746958036, 338375070, 689208529]) + configurator.fixed('env.id', 'HalfCheetah-v2') + configurator.fixed('env.standardize', True) # whether to use VecStandardize - ############################ - # Algorithm configurations # - ############################ - # 
Learning rate - config.add_item(name='algo:lr', val=1e-3) - # Discount factor - config.add_item(name='algo:gamma', val=0.99) - # Whether to use learning rate scheduler - config.add_item(name='algo:use_lr_scheduler', val=False) + configurator.fixed('network.hidden_sizes', [64, 64]) - ############################## - # Environment configurations # - ############################## - # Environment ID - # Note that better to run environment only one by one - # because of specific settings, e.g. train:T, log-interval for fair benchmark curve - config.add_item(name='env:id', val='CartPole-v1') - # Flag for continuous or discrete control - continuous = False - # Whether to standardize the observation and reward by running average - config.add_item(name='env:normalize', val=False) + configurator.fixed('algo.lr', 1e-3) + configurator.fixed('algo.use_lr_scheduler', True) + configurator.fixed('algo.gamma', 0.99) - ######################### - # Engine configurations # - ######################### - # Max training timesteps - # Alternative: 'train:iter' for training iterations - config.add_item(name='train:timestep', val=1e6) # recommended: 1e-6, i.e. 1M timesteps - # Number of Segment per training iteration - config.add_item(name='train:N', val=16) - # Number of timesteps per Segment - config.add_item(name='train:T', val=5) + configurator.fixed('agent.standardize_Q', False) # whether to standardize discounted returns + configurator.fixed('agent.max_grad_norm', 0.5) # grad clipping, set None to turn off + configurator.fixed('agent.entropy_coef', 0.01) + configurator.fixed('agent.value_coef', 0.5) + # only for continuous control + configurator.fixed('agent.min_std', 1e-6) # min threshould for std, avoid numerical instability + configurator.fixed('agent.std_style', 'exp') # std parameterization, 'exp' or 'softplus' + configurator.fixed('agent.constant_std', None) # constant std, set None to learn it + configurator.fixed('agent.std_state_dependent', False) # whether to learn std with state dependency + configurator.fixed('agent.init_std', 0.5) # initial std for state-independent std - # Evaluation: number of episodes to obtain average episode reward - # We do not specify T, because the Engine.eval will automatically use env.T for complete episode - config.add_item(name='eval:N', val=10) + configurator.fixed('train.timestep', 1e6) # either 'train.iter' or 'train.timestep' + configurator.fixed('train.N', 16) # number of segments per training iteration + configurator.fixed('train.T', 5) # fixed-length segment rolling + configurator.fixed('eval.N', 10) # number of episodes to evaluate, do not specify T for complete episode - ####################### - # Agent configuration # - ####################### - # Whether to standardize the discounted returns - config.add_item(name='agent:standardize', val=False) - # Gradient clipping with max gradient norm - config.add_item(name='agent:max_grad_norm', val=0.5) - # Coefficient for policy entropy loss - config.add_grid(name='agent:entropy_coef', val=[0.01, 0.1, 0.5, 1.0]) - # Coefficient for value loss - config.add_item(name='agent:value_coef', val=0.5) - # For Gaussian policy - if continuous: - # Min std threshould, avoid numerical instability - config.add_item(name='agent:min_std', val=1e-6) - # Use constant std; If use trainable std, put None - config.add_item(name='agent:constant_std', val=None) - # Whether to have state dependence for learning std - config.add_item(name='agent:std_state_dependent', val=False) - # Std parameterization: 'exp' or 'softplus' - 
config.add_item(name='agent:std_style', val='exp') + configurator.fixed('log.record_interval', 100) # interval to record the logging + configurator.fixed('log.print_interval', 1000) # interval to print the logging to screen + configurator.fixed('log.dir', 'logs') # logging directory - ########################## - # Logging configurations # - ########################## - # Periodic interval to log and save information - config.add_item(name='log:interval', val=100) - # Directory to save loggings - config.add_item(name='log:dir', val=f'logs/entropy_coef/{config.config_settings["env:id"][0]}') + list_config = configurator.make_configs() - # Auto-generate list of all possible configurations - configs = config.make_configs() + return list_config + + def make_seeds(self): + list_seed = [209652396, 398764591, 924231285, 1478610112, 441365315] - return configs + return list_seed + + def process_algo_result(self, config, seed, result): + assert result is None diff --git a/examples/policy_gradient/a2c/main.py b/examples/policy_gradient/a2c/main.py index ec48721f..3dd33c96 100644 --- a/examples/policy_gradient/a2c/main.py +++ b/examples/policy_gradient/a2c/main.py @@ -6,5 +6,5 @@ run_experiment(worker_class=ExperimentWorker, master_class=ExperimentMaster, - max_num_worker=50, + max_num_worker=None, daemonic_worker=None) diff --git a/examples/policy_gradient/a2c/policy.py b/examples/policy_gradient/a2c/policy.py index 3165f10e..37b831df 100644 --- a/examples/policy_gradient/a2c/policy.py +++ b/examples/policy_gradient/a2c/policy.py @@ -4,137 +4,23 @@ import torch.nn as nn import torch.nn.functional as F -from lagom.core.networks import BaseMLP -from lagom.core.policies import BaseCategoricalPolicy -from lagom.core.policies import BaseGaussianPolicy +from lagom.core.networks import BaseNetwork +from lagom.core.networks import make_fc +from lagom.core.networks import ortho_init -class CategoricalMLP(BaseMLP): +class Network(BaseNetwork): def make_params(self, config): - self.fc1 = nn.Linear(in_features=self.env_spec.observation_space.flat_dim, out_features=64) - self.fc2 = nn.Linear(in_features=64, out_features=64) - - self.action_head = nn.Linear(in_features=64, out_features=self.env_spec.action_space.flat_dim) - self.value_head = nn.Linear(in_features=64, out_features=1) + self.layers = make_fc(input_dim=self.env_spec.observation_space.flat_dim, + hidden_sizes=config['network.hidden_sizes']) + self.last_feature_dim = config['network.hidden_sizes'][-1] def init_params(self, config): - gain = nn.init.calculate_gain(nonlinearity='tanh') - - nn.init.orthogonal_(self.fc1.weight, gain=gain) - nn.init.constant_(self.fc1.bias, 0.0) - - nn.init.orthogonal_(self.fc2.weight, gain=gain) - nn.init.constant_(self.fc2.bias, 0.0) - - nn.init.orthogonal_(self.action_head.weight, gain=0.01) # Smaller scale for action head - nn.init.constant_(self.action_head.bias, 0.0) + for layer in self.layers: + ortho_init(layer, nonlinearity='tanh', constant_bias=0.0) - nn.init.orthogonal_(self.value_head.weight, gain=1.0) # no nonlinearity - nn.init.constant_(self.value_head.bias, 0.0) - def forward(self, x): - # Output dictionary - network_out = {} - - # Flatten the input - x = x.flatten(start_dim=1) - - # Forward pass through feature layers - x = torch.tanh(self.fc1(x)) - x = torch.tanh(self.fc2(x)) - - # Forward pass through action layers and record the output - action_scores = self.action_head(x) - network_out['action_scores'] = action_scores - - # Forward pass through value layer and record the output - state_value = 
self.value_head(x) - network_out['state_value'] = state_value - - return network_out + for layer in self.layers: + x = torch.tanh(layer(x)) - -class GaussianMLP(BaseMLP): - def make_params(self, config): - self.fc1 = nn.Linear(in_features=self.env_spec.observation_space.flat_dim, out_features=64) - self.fc2 = nn.Linear(in_features=64, out_features=64) - - self.mean_head = nn.Linear(in_features=64, out_features=self.env_spec.action_space.flat_dim) - if config['agent:constant_std'] is None: # no constant std provided, so train it - if config['agent:std_state_dependent']: # std is dependent on state - self.logvar_head = nn.Linear(in_features=64, out_features=self.env_spec.action_space.flat_dim) - else: # std is independent of state - # Do not initialize it in `init_params()` - self.logvar_head = nn.Parameter(torch.full([self.env_spec.action_space.flat_dim], 0.01)) - - self.value_head = nn.Linear(in_features=64, out_features=1) - - def init_params(self, config): - gain = nn.init.calculate_gain(nonlinearity='tanh') - - nn.init.orthogonal_(self.fc1.weight, gain=gain) - nn.init.constant_(self.fc1.bias, 0.0) - - nn.init.orthogonal_(self.fc2.weight, gain=gain) - nn.init.constant_(self.fc2.bias, 0.0) - - nn.init.orthogonal_(self.mean_head.weight, gain=0.01) # small initial mean around 0. - nn.init.constant_(self.mean_head.bias, 0.0) - if config['agent:constant_std'] is None and config['agent:std_state_dependent']: - nn.init.orthogonal_(self.logvar_head.weight, gain=0.01) - nn.init.constant_(self.logvar_head.bias, 0.0) - - nn.init.orthogonal_(self.value_head.weight, gain=1.0) # no nonlinearity - nn.init.constant_(self.value_head.bias, 0.0) - - def forward(self, x): - # Output dictionary - network_out = {} - - # Flatten the input - x = x.flatten(start_dim=1) - - # Forward pass through feature layers - x = torch.tanh(self.fc1(x)) - x = torch.tanh(self.fc2(x)) - - # Forward pass through action layers and record the output - mean = self.mean_head(x) - network_out['mean'] = mean - - if self.config['agent:constant_std'] is None: # learned std - if self.config['agent:std_state_dependent']: # state-dependent std, so forward pass - logvar = self.logvar_head(x) - else: # state-independent, so directly use it - logvar = self.logvar_head.expand_as(mean) - network_out['logvar'] = logvar - - # Forward pass through value layer and record the output - state_value = self.value_head(x) - network_out['state_value'] = state_value - - return network_out - - -class CategoricalPolicy(BaseCategoricalPolicy): - def process_network_output(self, network_out): - return network_out - - -class GaussianPolicy(BaseGaussianPolicy): - def process_network_output(self, network_out): - return network_out - - def constraint_action(self, action): - # Limit the action with valid range - # Note that we assume all Continuous action space with same low and same high for each dimension - # and asymmetric (absolute values between low and high should be identical) - low = np.unique(self.env_spec.action_space.low) - high = np.unique(self.env_spec.action_space.high) - assert low.ndim == 1 and high.ndim == 1 - assert -low.item() == high.item() - - # Enforce valid action in [low, high] - action = torch.clamp(action, min=low.item(), max=high.item()) - - return action + return x diff --git a/examples/policy_gradient/reinforce/experiment.py b/examples/policy_gradient/reinforce/experiment.py index e085d0d1..7052fcf3 100644 --- a/examples/policy_gradient/reinforce/experiment.py +++ b/examples/policy_gradient/reinforce/experiment.py @@ -37,7 +37,6 @@ def 
make_configs(self):
         configurator.fixed('agent.std_state_dependent', False)  # whether to learn std with state dependency
         configurator.fixed('agent.init_std', 1.0)  # initial std for state-independent std
 
-        configurator.fixed('train.timestep', 1e6)  # either 'train.iter' or 'train.timestep'
         configurator.fixed('train.N', 1)  # number of trajectories per training iteration
         configurator.fixed('train.T', 200)  # max allowed horizon
 
diff --git a/examples/policy_gradient/vpg/experiment.py b/examples/policy_gradient/vpg/experiment.py
index dc500e67..d2077b75 100644
--- a/examples/policy_gradient/vpg/experiment.py
+++ b/examples/policy_gradient/vpg/experiment.py
@@ -38,7 +38,6 @@ def make_configs(self):
         configurator.fixed('agent.std_state_dependent', False)  # whether to learn std with state dependency
         configurator.fixed('agent.init_std', 0.5)  # initial std for state-independent std
 
-        configurator.fixed('train.timestep', 1e6)  # either 'train.iter' or 'train.timestep'
         configurator.fixed('train.N', 1)  # number of trajectories per training iteration
         configurator.fixed('train.T', 200)  # max allowed horizon
 
diff --git a/lagom/agents/a2c_agent.py b/lagom/agents/a2c_agent.py
index 3a8816e0..db7985bf 100644
--- a/lagom/agents/a2c_agent.py
+++ b/lagom/agents/a2c_agent.py
@@ -5,100 +5,108 @@ import torch.nn.functional as F
 
 from .base_agent import BaseAgent
+
 from lagom.core.transform import Standardize
 
 
 class A2CAgent(BaseAgent):
-    """
-    Advantage Actor-Critic (A2C) with option to use Generalized Advantage Estimate (GAE)
+    r"""`Advantage Actor-Critic`_ (A2C) with option to use Generalized Advantage Estimate (GAE)
     
     The main difference of A2C is to use bootstrapping for estimating the advantage function and training value function.
     
-    Reference: https://arxiv.org/abs/1602.01783
+    .. _Advantage Actor-Critic:
+        https://arxiv.org/abs/1602.01783
+    
+    Like `OpenAI baselines`_ we use fixed-length segments of experience to compute returns and advantages.
+    
+    .. _OpenAI baselines:
+        https://blog.openai.com/baselines-acktr-a2c/
+    
+    .. note::
-    Note that we use fixed-length segments of experiment to compute returns and advantages.
-    https://blog.openai.com/baselines-acktr-a2c/
+        Use :class:`SegmentRunner` to collect data, not :class:`TrajectoryRunner`
-    For this purpose, please use SegmentRunner, not TrajectoryRunner to collect data for A2CAgent.
     """
-    def __init__(self, policy, optimizer, config, **kwargs):
+    def __init__(self, config, policy, optimizer, **kwargs):
         self.policy = policy
         self.optimizer = optimizer
 
         super().__init__(config, **kwargs)
 
-        self.accumulated_trained_timesteps = 0
+        # accumulated trained timesteps
+        self.total_T = 0
 
     def choose_action(self, obs):
-        # Convert to Tensor
-        # Note that the observation should be batched already (even if only one Segment)
-        if not torch.is_tensor(obs):
-            obs = torch.from_numpy(np.array(obs)).float()
-            obs = obs.to(self.device)  # move to device
-
-        # Call policy
-        # Note that all metrics should also be batched for SegmentRunner to work properly, check policy/network output.
- out_policy = self.policy(obs) - - # Dictionary of output data - output = {} - output = {**out_policy} - - return output + if not torch.is_tensor(obs): # Tensor conversion, already batched observation + obs = torch.from_numpy(np.asarray(obs)).float().to(self.device) + + # Call policy: all metrics should be batched properly for Runner to work properly + out_policy = self.policy(obs, out_keys=['action', 'action_logprob', 'state_value', + 'entropy', 'perplexity']) + + return out_policy def learn(self, D): + out = {} + batch_policy_loss = [] batch_value_loss = [] batch_entropy_loss = [] batch_total_loss = [] - # Iterate over list of Segment in D - for segment in D: + for segment in D: # iterate over segments # Get all boostrapped discounted returns as estimate of Q Qs = segment.all_bootstrapped_discounted_returns - # TODO: when use GAE of TDs, really standardize it ? biased magnitude of learned value get wrong TD error - # Standardize advantage estimates if required - # encourage/discourage half of performed actions, respectively. - if self.config['agent:standardize']: - Qs = Standardize()(Qs) - - # Get all state values (without V_s_next in final transition) - Vs = segment.all_info('V_s') - + + # Standardize: encourage/discourage half of performed actions + if self.config['agent.standardize_Q']: + Qs = Standardize()(Qs).tolist() + + # Get all state values and intermediate terminal state and final state + Vs, finals = segment.all_V + final_Vs, final_dones = zip(*finals) + assert len(Vs) == len(segment.transitions) + # Advantage estimates - As = [Q - V.item() for Q, V in zip(Qs, Vs)] - + As = [Q - V.item() for Q, V in zip(Qs, Vs)] + # Get all log-probabilities and entropies logprobs = segment.all_info('action_logprob') entropies = segment.all_info('entropy') - + # Estimate policy gradient for all time steps and record all losses policy_loss = [] - value_loss = [] entropy_loss = [] + value_loss = [] for logprob, entropy, A, Q, V in zip(logprobs, entropies, As, Qs, Vs): policy_loss.append(-logprob*A) - value_loss.append(F.mse_loss(V, torch.tensor(Q).view_as(V).to(V.device))) entropy_loss.append(-entropy) + value_loss.append(F.mse_loss(V, torch.tensor(Q).view_as(V).to(V.device))) + for final_V, final_done in zip(final_Vs, final_dones): # learn terminal state value as zero + if final_done: + value_loss.append(F.mse_loss(final_V, torch.tensor(0.0).view_as(V).to(V.device))) # Average over losses for all time steps policy_loss = torch.stack(policy_loss).mean() - value_loss = torch.stack(value_loss).mean() entropy_loss = torch.stack(entropy_loss).mean() + value_loss = torch.stack(value_loss).mean() # Calculate total loss - value_coef = self.config['agent:value_coef'] - entropy_coef = self.config['agent:entropy_coef'] + entropy_coef = self.config['agent.entropy_coef'] + value_coef = self.config['agent.value_coef'] total_loss = policy_loss + value_coef*value_loss + entropy_coef*entropy_loss - + # Record all losses batch_policy_loss.append(policy_loss) - batch_value_loss.append(value_loss) batch_entropy_loss.append(entropy_loss) + batch_value_loss.append(value_loss) batch_total_loss.append(total_loss) - # Compute loss (average over segments) - loss = torch.stack(batch_total_loss).mean() # use stack because each element is zero-dim tensor + # Compute loss (average over segments): use `stack` as each is zero-dim + loss = torch.stack(batch_total_loss).mean() + policy_loss = torch.stack(batch_policy_loss).mean() + entropy_loss = torch.stack(batch_entropy_loss).mean() + value_loss = 
torch.stack(batch_value_loss).mean() # Zero-out gradient buffer self.optimizer.zero_grad() @@ -106,59 +114,38 @@ def learn(self, D): loss.backward() # Clip gradient norms if required - if self.config['agent:max_grad_norm'] is not None: + if self.config['agent.max_grad_norm'] is not None: nn.utils.clip_grad_norm_(parameters=self.policy.network.parameters(), - max_norm=self.config['agent:max_grad_norm'], + max_norm=self.config['agent.max_grad_norm'], norm_type=2) # Decay learning rate if required if hasattr(self, 'lr_scheduler'): - if 'train:iter' in self.config: # iteration-based training, so just increment epoch by default + if 'train.iter' in self.config: # iteration-based training, so just increment epoch by default self.lr_scheduler.step() - elif 'train:timestep' in self.config: # timestep-based training, increment with timesteps - self.lr_scheduler.step(self.accumulated_trained_timesteps) + elif 'train.timestep' in self.config: # timestep-based training, increment with timesteps + self.lr_scheduler.step(self.total_T) else: - raise KeyError('expected train:iter or train:timestep in config, but none of them exist') + raise KeyError('expected `train.iter` or `train.timestep` in config, but none of them exist') # Take a gradient step self.optimizer.step() # Accumulate trained timesteps - self.accumulated_trained_timesteps += sum([segment.T for segment in D]) - - # Output dictionary for different losses - # TODO: if no more backprop needed, record with .item(), save memory without store computation graph - output = {} - output['loss'] = loss - output['batch_policy_loss'] = batch_policy_loss - output['batch_value_loss'] = batch_value_loss - output['batch_entropy_loss'] = batch_entropy_loss - output['batch_total_loss'] = batch_total_loss - if hasattr(self, 'lr_scheduler'): - output['current_lr'] = self.lr_scheduler.get_lr() + self.total_T += sum([segment.T for segment in D]) - return output + # Output dictionary: use `item()` to save memory if no more backprop needed + out['loss'] = loss.item() + out['policy_loss'] = policy_loss.item() + out['entropy_loss'] = entropy_loss.item() + out['value_loss'] = value_loss.item() + if hasattr(self, 'lr_scheduler'): + out['current_lr'] = self.lr_scheduler.get_lr() + + return out - def save(self, filename): - self.policy.network.save(filename) + def save(self, f): + self.policy.network.save(f) - def load(self, filename): - self.policy.network.load(filename) - - -""" - - # Generalized Advantage Estimation (GAE) - all_TD = episode.all_TD - alpha = episode.gamma*self.config['GAE_lambda'] - GAE_advantages = ExponentialFactorCumSum(alpha=alpha).process(all_TD) - # Standardize advantages to [-1, 1], encourage/discourage half of actions - GAE_advantages = Standardize().process(GAE_advantages) - - - for logprob, V, Q, GAE_advantage, entropy in zip(log_probs, Vs, Qs, GAE_advantages, entropies): - policy_loss.append(-logprob*GAE_advantage) - value_loss.append(F.mse_loss(V, torch.Tensor([Q]).unsqueeze(0)).unsqueeze(0)) - entropy_loss.append(-entropy) - -""" + def load(self, f): + self.policy.network.load(f) diff --git a/lagom/agents/vpg_agent.py b/lagom/agents/vpg_agent.py index 8e4240eb..7c6cf848 100644 --- a/lagom/agents/vpg_agent.py +++ b/lagom/agents/vpg_agent.py @@ -89,7 +89,7 @@ def learn(self, D): loss = torch.stack(batch_total_loss).mean() policy_loss = torch.stack(batch_policy_loss).mean() entropy_loss = torch.stack(batch_entropy_loss).mean() - value_loss = torch.stack(batch_value_loss).mean() + value_loss = torch.stack(batch_value_loss).mean() # Zero-out 
gradient buffer
         self.optimizer.zero_grad()
diff --git a/test/test_policies.py b/test/test_policies.py
index d81f357f..0f168c67 100644
--- a/test/test_policies.py
+++ b/test/test_policies.py
@@ -131,7 +131,7 @@ def _check_policy(policy):
         assert np.allclose(policy.network.value_head.bias.detach().numpy(), 0.0)
 
         obs = torch.from_numpy(np.array(env_spec.env.reset())).float()
-        out_policy = policy(obs, out_keys=['action', 'action_logprob', 'entropy', 'perplexity'])
+        out_policy = policy(obs, out_keys=['action', 'action_logprob', 'state_value', 'entropy', 'perplexity'])
         assert isinstance(out_policy, dict)
         assert 'action' in out_policy
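
Editor's note: the patch to A2CAgent.learn replaces per-transition loss bookkeeping with scalar summaries, but the underlying A2C update is unchanged: bootstrapped n-step returns over a fixed-length segment, advantages against the critic, and a combined policy/value/entropy loss. The sketch below restates that update with plain PyTorch tensors instead of lagom's Segment API; the names a2c_losses, rewards, dones, values, bootstrap_value, log_probs and entropies are hypothetical stand-ins for what a SegmentRunner would collect, not part of the patch.

import torch
import torch.nn.functional as F

def a2c_losses(rewards, dones, values, bootstrap_value, log_probs, entropies,
               gamma=0.99, entropy_coef=0.01, value_coef=0.5):
    # Bootstrapped discounted returns over one fixed-length segment:
    # Q_t = r_t + gamma*Q_{t+1}, seeded with V(s_T) unless the segment ended in a terminal state.
    T = len(rewards)
    Q = bootstrap_value.detach()
    Qs = [None]*T
    for t in reversed(range(T)):
        Q = rewards[t] + gamma*(1.0 - dones[t])*Q
        Qs[t] = Q
    Qs = torch.stack(Qs)

    As = Qs - values.detach()                        # advantages, no gradient through the critic
    policy_loss = -(log_probs*As).mean()             # policy gradient surrogate
    entropy_loss = -entropies.mean()                 # minimizing negative entropy encourages exploration
    value_loss = F.mse_loss(values, Qs)              # critic regression towards bootstrapped returns
    total = policy_loss + value_coef*value_loss + entropy_coef*entropy_loss
    return total, policy_loss, value_loss, entropy_loss

# Dummy segment of length 5, mirroring train.T=5 in the configuration above:
T = 5
rewards = torch.randn(T)
dones = torch.zeros(T)                               # no terminal state inside this segment
values = torch.randn(T, requires_grad=True)          # V(s_t) from the value head
bootstrap_value = torch.randn(())                    # V(s_T) used for bootstrapping
log_probs = torch.randn(T, requires_grad=True)
entropies = torch.rand(T, requires_grad=True)
loss, *_ = a2c_losses(rewards, dones, values, bootstrap_value, log_probs, entropies)
loss.backward()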
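
A second point worth illustrating is the evaluation-time handling of VecStandardize in algo.py and Engine.eval above: the training wrapper keeps running estimates of the observation mean and std, while the evaluation wrapper receives those statistics as constants (constant_obs_mean/constant_obs_std) and is re-synchronized before each evaluation. The stand-alone class below is a hypothetical, minimal sketch of that idea (scalar observations, Welford running statistics); it is not lagom's implementation.

import numpy as np

class ObsNormalizer:
    def __init__(self, eps=1e-8, clip=10.0, constant_mean=None, constant_std=None):
        self.eps, self.clip = eps, clip
        self.constant_mean, self.constant_std = constant_mean, constant_std
        self.n, self.mu, self.var = 0, 0.0, 0.0      # running statistics (scalar case for brevity)

    def __call__(self, obs):
        if self.constant_mean is not None:           # evaluation: frozen statistics, no updates
            mean, std = self.constant_mean, self.constant_std
        else:                                        # training: update running mean/variance (Welford)
            self.n += 1
            delta = obs - self.mu
            self.mu += delta/self.n
            self.var += (delta*(obs - self.mu) - self.var)/self.n
            mean, std = self.mu, np.sqrt(self.var)
        return float(np.clip((obs - mean)/(std + self.eps), -self.clip, self.clip))

train_norm = ObsNormalizer()                         # plays the role of the training wrapper
for obs in np.random.randn(1000):
    train_norm(obs)
# Before evaluating, freeze the current training statistics into the evaluation normalizer,
# analogous to assigning obs_runningavg.mu/.sigma to constant_obs_mean/constant_obs_std above.
eval_norm = ObsNormalizer(constant_mean=train_norm.mu, constant_std=np.sqrt(train_norm.var))
print(eval_norm(1.5))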