Skip to content

Commit

Permalink
Added success percentage logging and a script for plotting logs
Browse files Browse the repository at this point in the history
  • Loading branch information
aravindr93 committed Jun 18, 2021
1 parent bbed03a commit 3871d93
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 0 deletions.
5 changes: 5 additions & 0 deletions mjrl/algos/batch_reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,8 @@ def log_rollout_statistics(self, paths):
self.logger.log_kv('stoc_pol_std', std_return)
self.logger.log_kv('stoc_pol_max', max_return)
self.logger.log_kv('stoc_pol_min', min_return)
try:
success_rate = self.env.env.env.evaluate_success(paths)
self.logger.log_kv('rollout_success', success_rate)
except:
pass
55 changes: 55 additions & 0 deletions mjrl/utils/plot_from_logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import argparse
import pickle
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']

parser = argparse.ArgumentParser(description='Script to explore the data generated by an experiment.')
parser.add_argument('--data', '-d', type=str, required=True, help='location of the .pickle log data file')
parser.add_argument('--output', '-o', type=str, required=True, help='location to store results as a png')
parser.add_argument('--xkey', '-x', type=str, default=None, help='the key to use for x axis in plots')
parser.add_argument('--xscale', '-s', type=int, default=1, help='scaling for the x axis (optional)')
args = parser.parse_args()

# get inputs and setup output file
if '.png' in args.output:
OUT_FILE = args.output
else:
OUT_FILE = args.output + '/plot.png'
data = pickle.load(open(args.data, 'rb'))
xscale = 1 if args.xscale is None else args.xscale
if args.xkey == 'num_samples':
xscale = xscale if 'act_repeat' not in data.keys() else data['act_repeat'][-1]

dict_keys = list(data.keys())
for k in dict_keys:
if len(data[k]) == 1: del(data[k])

# plot layout
nplt = len(data.keys())
ncol = 4
nrow = int(np.ceil(nplt/ncol))

# plot data
xkey = args.xkey
start_idx = 2
end_idx = max([len(data[k]) for k in data.keys()])
xdata = np.arange(end_idx) if (xkey is None or xkey == 'None') else \
[np.sum(data[xkey][:i+1]) * xscale for i in range(len(data[xkey]))]

# make the plot
plt.figure(figsize=(15,15), dpi=60)
for idx, key in enumerate(data.keys()):
plt.subplot(nrow, ncol, idx+1)
plt.tight_layout()
try:
last_idx = min(end_idx, len(data[key]))
plt.plot(xdata[start_idx:last_idx], data[key][start_idx:last_idx], color=colors[idx%7], linewidth=3)
except:
pass
plt.title(key)

plt.savefig(OUT_FILE, dpi=100, bbox_inches="tight")
5 changes: 5 additions & 0 deletions mjrl/utils/train_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ def train_agent(job_name, agent,
mean_pol_perf = np.mean([np.sum(path['rewards']) for path in eval_paths])
if agent.save_logs:
agent.logger.log_kv('eval_score', mean_pol_perf)
try:
eval_success = e.env.env.evaluate_success(eval_paths)
agent.logger.log_kv('eval_success', eval_success)
except:
pass

if i % save_freq == 0 and i > 0:
if agent.save_logs:
Expand Down

0 comments on commit 3871d93

Please sign in to comment.