
Commit

Update
Former-commit-id: b73f9f094298c17cf8dbf4d921d666ed3df6b451 [formerly d0e5b23b985d3ede488cc0866ab16796bea317ba]
Former-commit-id: 1e1166290a20adfe31d678533f7022141ecc7af0
zuoxingdong committed Nov 2, 2018
1 parent 54c9f0d commit 24cd620
Showing 24 changed files with 1,116 additions and 758 deletions.
17 changes: 17 additions & 0 deletions docs/source/history.rst
@@ -15,3 +15,20 @@ lagom.history: History

.. autoclass:: Segment
:members:

Metrics
----------------

.. currentmodule:: lagom.history.metrics

.. autofunction:: terminal_state_from_trajectory

.. autofunction:: terminal_state_from_segment

.. autofunction:: final_state_from_trajectory

.. autofunction:: final_state_from_segment

.. autofunction:: bootstrapped_returns_from_trajectory

.. autofunction:: bootstrapped_returns_from_segment
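
The new metrics helpers documented above can be sketched in use roughly as follows; the argument lists are assumptions made for illustration, not taken from this commit, so check lagom.history.metrics for the actual signatures.

# Hypothetical sketch -- argument order and names are assumed, not from this diff.
from lagom.history.metrics import final_state_from_trajectory
from lagom.history.metrics import bootstrapped_returns_from_trajectory

def summarize(trajectory, V_last, gamma=0.99):
    # trajectory: a lagom Trajectory collected by a runner
    # V_last: critic's value estimate at the trajectory's final state
    last_obs = final_state_from_trajectory(trajectory)                    # assumed: returns the last observation
    Qs = bootstrapped_returns_from_trajectory(trajectory, V_last, gamma)  # assumed: bootstrapped discounted returns
    return last_obs, Qs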
113 changes: 17 additions & 96 deletions examples/policy_gradient/a2c/algo.py
@@ -3,55 +3,42 @@
from itertools import count

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from lagom import set_global_seeds
from lagom import Logger
from lagom.utils import pickle_dump
from lagom.utils import set_global_seeds

from lagom import BaseAlgorithm
from lagom import pickle_dump

from lagom.envs import make_gym_env
from lagom.envs import make_vec_env
from lagom.envs import EnvSpec
from lagom.envs.vec_env import SerialVecEnv
from lagom.envs.vec_env import ParallelVecEnv
from lagom.envs.vec_env import VecStandardize

from lagom.core.policies import CategoricalPolicy
from lagom.core.policies import GaussianPolicy

from lagom.runner import TrajectoryRunner
from lagom.runner import SegmentRunner

from lagom.agents import A2CAgent

from model import Agent
from engine import Engine
from policy import Network
from policy import LSTM


class Algorithm(BaseAlgorithm):
def __call__(self, config, seed, device_str):
def __call__(self, config, seed, device):
set_global_seeds(seed)
device = torch.device(device_str)
logdir = Path(config['log.dir']) / str(config['ID']) / str(seed)

# Environment related

env = make_vec_env(vec_env_class=SerialVecEnv,
make_env=make_gym_env,
env_id=config['env.id'],
num_env=config['train.N'], # batched environment
init_seed=seed,
rolling=True)
init_seed=seed)
eval_env = make_vec_env(vec_env_class=SerialVecEnv,
make_env=make_gym_env,
env_id=config['env.id'],
num_env=config['eval.N'],
init_seed=seed,
rolling=False)
init_seed=seed)
if config['env.standardize']: # running averages of observation and reward
env = VecStandardize(venv=env,
use_obs=True,
@@ -60,102 +47,37 @@ def __call__(self, config, seed, device_str):
clip_reward=10.,
gamma=0.99,
eps=1e-8)
eval_env = VecStandardize(venv=eval_env, # remember to synchronize running averages during evaluation !!!
eval_env = VecStandardize(venv=eval_env,
use_obs=True,
use_reward=False, # do not process rewards, no training
clip_obs=env.clip_obs,
clip_reward=env.clip_reward,
gamma=env.gamma,
eps=env.eps,
constant_obs_mean=env.obs_runningavg.mu, # use current running average as constant
constant_obs_mean=env.obs_runningavg.mu,
constant_obs_std=env.obs_runningavg.sigma)
env_spec = EnvSpec(env)

agent = Agent(config, env_spec, device)

# Network and policy
if config['network.recurrent']:
network = LSTM(config=config, device=device, env_spec=env_spec)
else:
network = Network(config=config, device=device, env_spec=env_spec)
if env_spec.control_type == 'Discrete':
policy = CategoricalPolicy(config=config,
network=network,
env_spec=env_spec,
device=device,
learn_V=True)
elif env_spec.control_type == 'Continuous':
policy = GaussianPolicy(config=config,
network=network,
env_spec=env_spec,
device=device,
learn_V=True,
min_std=config['agent.min_std'],
std_style=config['agent.std_style'],
constant_std=config['agent.constant_std'],
std_state_dependent=config['agent.std_state_dependent'],
init_std=config['agent.init_std'])

# Optimizer and learning rate scheduler
optimizer = optim.Adam(policy.network.parameters(), lr=config['algo.lr'])
if config['algo.use_lr_scheduler']:
if 'train.iter' in config: # iteration-based
max_epoch = config['train.iter']
elif 'train.timestep' in config: # timestep-based
max_epoch = config['train.timestep'] + 1 # avoid zero lr in final iteration
lambda_f = lambda epoch: 1 - epoch/max_epoch
lr_scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_f)

# Agent
kwargs = {'device': device}
if config['algo.use_lr_scheduler']:
kwargs['lr_scheduler'] = lr_scheduler
agent = A2CAgent(config=config,
policy=policy,
optimizer=optimizer,
**kwargs)

# Runner
runner = SegmentRunner(agent=agent,
env=env,
gamma=config['algo.gamma'])
eval_runner = TrajectoryRunner(agent=agent,
env=eval_env,
gamma=1.0)
runner = SegmentRunner(config, agent, env)
eval_runner = TrajectoryRunner(config, agent, eval_env)

# Engine
engine = Engine(agent=agent,
runner=runner,
config=config,
eval_runner=eval_runner)
engine = Engine(agent, runner, config, eval_runner=eval_runner)

# Training and evaluation
train_logs = []
eval_logs = []

if config['network.recurrent']:
rnn_states_buffer = agent.policy.rnn_states # for SegmentRunner

for i in count():
if 'train.iter' in config and i >= config['train.iter']: # enough iterations
break
elif 'train.timestep' in config and agent.total_T >= config['train.timestep']: # enough timesteps
break

if config['network.recurrent']:
if isinstance(rnn_states_buffer, list): # LSTM: [h, c]
rnn_states_buffer = [buf.detach() for buf in rnn_states_buffer]
else:
rnn_states_buffer = rnn_states_buffer.detach()
agent.policy.rnn_states = rnn_states_buffer

train_output = engine.train(n=i)
train_output = engine.train(i)

# Logging
if i == 0 or (i+1) % config['log.record_interval'] == 0 or (i+1) % config['log.print_interval'] == 0:
train_log = engine.log_train(train_output)

if config['network.recurrent']:
rnn_states_buffer = agent.policy.rnn_states # for SegmentRunner

with torch.no_grad(): # disable grad, save memory
eval_output = engine.eval(n=i)
eval_log = engine.log_eval(eval_output)
@@ -164,7 +86,6 @@ def __call__(self, config, seed, device_str):
train_logs.append(train_log)
eval_logs.append(eval_log)

# Save all loggings
pickle_dump(obj=train_logs, f=logdir/'train_logs', ext='.pkl')
pickle_dump(obj=eval_logs, f=logdir/'eval_logs', ext='.pkl')

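For context, a minimal driver sketch for the updated entry point: __call__ now takes a torch.device directly instead of a device string. The key names below are the ones referenced in this file and engine.py; the values and the environment id are placeholders, and the real experiment config likely defines additional keys consumed by Agent.

# Hypothetical driver -- values and env id are placeholders, not from this commit.
import torch
from algo import Algorithm

config = {
    'ID': 0,
    'log.dir': 'logs',
    'log.record_interval': 10,
    'log.print_interval': 10,
    'env.id': 'CartPole-v1',   # placeholder environment id
    'env.standardize': True,   # wrap env/eval_env with VecStandardize
    'train.N': 16,             # number of batched training environments
    'train.T': 5,              # segment length per training iteration
    'train.iter': 100,         # stop after this many training iterations
    'eval.N': 10,              # number of evaluation environments
}

algo = Algorithm()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
algo(config, seed=0, device=device)   # previously: __call__(config, seed, device_str)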
74 changes: 32 additions & 42 deletions examples/policy_gradient/a2c/engine.py
@@ -3,7 +3,7 @@
import torch

from lagom import Logger
from lagom import color_str
from lagom.utils import color_str

from lagom.engine import BaseEngine

@@ -15,12 +15,10 @@

class Engine(BaseEngine):
def train(self, n):
self.agent.policy.network.train() # train mode
self.agent.train()

# Collect a list of Segment
D = self.runner(T=self.config['train.T'])

# Train agent with collected data
out_agent = self.agent.learn(D)

train_output = {}
@@ -31,35 +29,31 @@ def train(self, n):
return train_output

def log_train(self, train_output, **kwargs):
# Unpack
D = train_output['D']
out_agent = train_output['out_agent']
n = train_output['n']

# Loggings
logger = Logger(name='train_logger')
logger.log('train_iteration', n+1) # starts from 1
if self.config['algo.use_lr_scheduler']:
logger.log('current_lr', out_agent['current_lr'])

logger.log('loss', out_agent['loss'])
logger.log('policy_loss', out_agent['policy_loss'])
logger.log('policy_entropy', -out_agent['entropy_loss']) # entropy: negative entropy loss
logger.log('value_loss', out_agent['value_loss'])
logger = Logger()
logger('train_iteration', n+1) # starts from 1
if 'current_lr' in out_agent:
logger('current_lr', out_agent['current_lr'])
logger('loss', out_agent['loss'])
logger('policy_loss', out_agent['policy_loss'])
logger('policy_entropy', -out_agent['entropy_loss'])
logger('value_loss', out_agent['value_loss'])

all_immediate_reward = [segment.all_r for segment in D]
num_timesteps = sum([segment.T for segment in D])

logger.log('num_segments', len(D))
logger.log('num_subsegments', sum([len(segment.trajectories) for segment in D]))
logger.log('num_timesteps', num_timesteps)
logger.log('accumulated_trained_timesteps', self.agent.total_T)
logger.log('average_immediate_reward', np.mean(all_immediate_reward))
logger.log('std_immediate_reward', np.std(all_immediate_reward))
logger.log('min_immediate_reward', np.min(all_immediate_reward))
logger.log('max_immediate_reward', np.max(all_immediate_reward))
logger('num_segments', len(D))
logger('num_subsegments', sum([len(segment.trajectories) for segment in D]))
logger('num_timesteps', num_timesteps)
logger('accumulated_trained_timesteps', self.agent.total_T)
logger('average_immediate_reward', np.mean(all_immediate_reward))
logger('std_immediate_reward', np.std(all_immediate_reward))
logger('min_immediate_reward', np.min(all_immediate_reward))
logger('max_immediate_reward', np.max(all_immediate_reward))

# Dump loggings
if n == 0 or (n+1) % self.config['log.print_interval'] == 0:
print('-'*50)
logger.dump(keys=None, index=None, indent=0)
@@ -68,16 +62,15 @@ def log_train(self, train_output, **kwargs):
return logger.logs

def eval(self, n):
self.agent.policy.network.eval() # evaluation mode
self.agent.eval()

# Synchronize running average of observations for evaluation
if self.config['env.standardize']:
self.eval_runner.env.constant_obs_mean = self.runner.env.obs_runningavg.mu
self.eval_runner.env.constant_obs_std = self.runner.env.obs_runningavg.sigma

# Collect a list of Trajectory
T = self.eval_runner.env.T
D = self.eval_runner(T=T)
D = self.eval_runner(T)

eval_output = {}
eval_output['D'] = D
@@ -87,29 +80,26 @@ def eval(self, n):
return eval_output

def log_eval(self, eval_output, **kwargs):
# Unpack
D = eval_output['D']
n = eval_output['n']
T = eval_output['T']

# Loggings
logger = Logger(name='eval_logger')
logger = Logger()

batch_returns = [sum(trajectory.all_r) for trajectory in D]
batch_T = [trajectory.T for trajectory in D]

logger.log('evaluation_iteration', n+1)
logger.log('num_trajectories', len(D))
logger.log('max_allowed_horizon', T)
logger.log('average_horizon', np.mean(batch_T))
logger.log('num_timesteps', np.sum(batch_T))
logger.log('accumulated_trained_timesteps', self.agent.total_T)
logger.log('average_return', np.mean(batch_returns))
logger.log('std_return', np.std(batch_returns))
logger.log('min_return', np.min(batch_returns))
logger.log('max_return', np.max(batch_returns))

# Dump loggings
logger('evaluation_iteration', n+1)
logger('num_trajectories', len(D))
logger('max_allowed_horizon', T)
logger('average_horizon', np.mean(batch_T))
logger('num_timesteps', np.sum(batch_T))
logger('accumulated_trained_timesteps', self.agent.total_T)
logger('average_return', np.mean(batch_returns))
logger('std_return', np.std(batch_returns))
logger('min_return', np.min(batch_returns))
logger('max_return', np.max(batch_returns))

if n == 0 or (n+1) % self.config['log.print_interval'] == 0:
print(color_str('+'*50, 'yellow', 'bold'))
logger.dump(keys=None, index=None, indent=0)
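A small usage note on the Logger change above: records are now added through Logger.__call__ rather than logger.log, and the name argument is no longer passed. A minimal sketch, using only the calls that appear in this diff:

from lagom import Logger

logger = Logger()                              # no name argument anymore
logger('train_iteration', 1)                   # __call__ replaces logger.log('key', value)
logger('average_return', 123.4)
logger.dump(keys=None, index=None, indent=0)   # same dump call as in log_train/log_eval
records = logger.logs                          # what log_train/log_eval return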
(21 more changed files not shown)
