From 3871d93763d3b49c4741e6daeaebbc605fe140dc Mon Sep 17 00:00:00 2001
From: Aravind Rajeswaran
Date: Fri, 18 Jun 2021 15:27:28 -0700
Subject: [PATCH] Added success percentage logging and a script for plotting logs

---
Review notes (commentary only, not part of the commit message):
- best-effort success logging catches `Exception` instead of using a bare
  `except:`, so KeyboardInterrupt / SystemExit still propagate
- plot_from_logs.py: closes the log file, checks the output suffix with
  endswith(), reports series that cannot be drawn, lays out the figure once
- hunk sizes and the diffstat reflect the revised contents; the blob ids on
  the `index` lines are those of the original commit and are now stale

 mjrl/algos/batch_reinforce.py |  5 +++
 mjrl/utils/plot_from_logs.py  | 66 +++++++++++++++++++++++++++++++++++
 mjrl/utils/train_agent.py     |  5 +++
 3 files changed, 76 insertions(+)
 create mode 100644 mjrl/utils/plot_from_logs.py

diff --git a/mjrl/algos/batch_reinforce.py b/mjrl/algos/batch_reinforce.py
index 359e5a0..adaf4c5 100644
--- a/mjrl/algos/batch_reinforce.py
+++ b/mjrl/algos/batch_reinforce.py
@@ -207,3 +207,8 @@ def log_rollout_statistics(self, paths):
         self.logger.log_kv('stoc_pol_std', std_return)
         self.logger.log_kv('stoc_pol_max', max_return)
         self.logger.log_kv('stoc_pol_min', min_return)
+        try:
+            success_rate = self.env.env.env.evaluate_success(paths)
+            self.logger.log_kv('rollout_success', success_rate)
+        except Exception:
+            pass  # best effort: skip the metric when success cannot be evaluated
diff --git a/mjrl/utils/plot_from_logs.py b/mjrl/utils/plot_from_logs.py
new file mode 100644
index 0000000..1bb4373
--- /dev/null
+++ b/mjrl/utils/plot_from_logs.py
@@ -0,0 +1,66 @@
+"""Plot every multi-entry series from a pickled training log into one PNG grid."""
+import os
+import argparse
+import pickle
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k']
+
+parser = argparse.ArgumentParser(description='Script to explore the data generated by an experiment.')
+parser.add_argument('--data', '-d', type=str, required=True, help='location of the .pickle log data file')
+parser.add_argument('--output', '-o', type=str, required=True, help='location to store results as a png')
+parser.add_argument('--xkey', '-x', type=str, default=None, help='the key to use for x axis in plots')
+parser.add_argument('--xscale', '-s', type=int, default=1, help='scaling for the x axis (optional)')
+args = parser.parse_args()
+
+# get inputs and setup output file
+# an explicit .png path is used as-is; anything else is treated as a directory
+if args.output.endswith('.png'):
+    OUT_FILE = args.output
+else:
+    OUT_FILE = os.path.join(args.output, 'plot.png')
+# NOTE: pickle can execute arbitrary code on load; only open trusted log files
+with open(args.data, 'rb') as log_file:
+    data = pickle.load(log_file)
+xscale = 1 if args.xscale is None else args.xscale
+if args.xkey == 'num_samples':
+    # when the log records an action-repeat factor it replaces the --xscale value
+    xscale = xscale if 'act_repeat' not in data else data['act_repeat'][-1]
+
+# drop keys that hold a single entry (nothing to draw as a curve)
+dict_keys = list(data.keys())
+for k in dict_keys:
+    if len(data[k]) == 1: del(data[k])
+if not data:
+    raise SystemExit('no series with more than one entry found in ' + args.data)
+
+# plot layout
+nplt = len(data)
+ncol = 4
+nrow = int(np.ceil(nplt/ncol))
+
+# plot data
+xkey = args.xkey
+start_idx = 2
+end_idx = max(len(data[k]) for k in data)
+# x axis is the iteration index, or the scaled running total of the xkey series
+xdata = np.arange(end_idx) if (xkey is None or xkey == 'None') else \
+        np.cumsum(data[xkey]) * xscale
+
+# make the plot
+plt.figure(figsize=(15,15), dpi=60)
+for idx, key in enumerate(data):
+    plt.subplot(nrow, ncol, idx+1)
+    try:
+        last_idx = min(end_idx, len(data[key]))
+        plt.plot(xdata[start_idx:last_idx], data[key][start_idx:last_idx], color=colors[idx % len(colors)], linewidth=3)
+    except Exception as err:
+        # best effort: a series that cannot be drawn leaves its panel empty
+        print('could not plot %s: %s' % (key, err))
+    plt.title(key)
+
+# lay out once, after every panel has its curve and title
+plt.tight_layout()
+plt.savefig(OUT_FILE, dpi=100, bbox_inches="tight")
diff --git a/mjrl/utils/train_agent.py b/mjrl/utils/train_agent.py
index d869f7d..688b638 100644
--- a/mjrl/utils/train_agent.py
+++ b/mjrl/utils/train_agent.py
@@ -114,6 +114,11 @@ def train_agent(job_name, agent,
             mean_pol_perf = np.mean([np.sum(path['rewards']) for path in eval_paths])
             if agent.save_logs:
                 agent.logger.log_kv('eval_score', mean_pol_perf)
+                try:
+                    eval_success = e.env.env.evaluate_success(eval_paths)
+                    agent.logger.log_kv('eval_success', eval_success)
+                except Exception:
+                    pass  # best effort: skip the metric when success cannot be evaluated
 
         if i % save_freq == 0 and i > 0:
             if agent.save_logs: