From 28eb3dc2d3803881b26a8b5644124fa0049d915a Mon Sep 17 00:00:00 2001
From: Gargi Vaidya
Date: Sat, 13 Feb 2021 11:18:59 -0600
Subject: [PATCH] Update callback function to save model

---
 parrot_training.py | 34 +++++++++++++++++-----------------
 parrotenv.py       |  6 ++----
 2 files changed, 19 insertions(+), 21 deletions(-)

diff --git a/parrot_training.py b/parrot_training.py
index 97e7955..8649577 100644
--- a/parrot_training.py
+++ b/parrot_training.py
@@ -1,26 +1,21 @@
-# Continuous Action Parrot Drone Environment 3D Space Training Script
+"""
+Benchmark reinforcement learning (RL) algorithms from Stable Baselines 2.10.
+Author: Gargi Vaidya & Vishnu Saj
+Note: Modify the RL algorithm from Stable Baselines and tune the hyperparameters.
+"""
 import olympe
 from matplotlib import pyplot as plt
 from parrotenv import ParrotEnv
 import subprocess
 from olympe.messages.ardrone3.Piloting import TakeOff, moveBy, Landing,moveTo,PCMD
 from olympe.messages.ardrone3.PilotingState import FlyingStateChanged, AttitudeChanged, moveByChanged, AltitudeChanged, PositionChanged
-
-drone = olympe.Drone("10.202.0.1")
-drone.connection()
-command = "echo '{\"jsonrpc\": \"2.0\", \"method\": \"SetParam\", \"params\": {\"machine\":\"anafi4k\", \"object\":\"lipobattery/lipobattery\", \"parameter\":\"discharge_speed_factor\", \"value\":\"0\"}, \"id\": 1}' | curl -d @- http://localhost:8383 | python -m json.tool"
-subprocess.run(command, shell=True)
-assert drone(TakeOff()>> FlyingStateChanged(state="hovering", _timeout=5)).wait().success()
-#drone.start_piloting()
-
 import os
 import csv
 import gym
 import csv
 import numpy as np
 import matplotlib.pyplot as plt
-
 from stable_baselines import TD3
 from stable_baselines.td3.policies import MlpPolicy
 from stable_baselines.common.vec_env import DummyVecEnv
@@ -29,6 +24,13 @@
 from stable_baselines.results_plotter import load_results, ts2xy
 from stable_baselines.common.callbacks import BaseCallback
 
+# Define the drone model and take off
+drone = olympe.Drone("10.202.0.1")
+drone.connection()
+command = "echo '{\"jsonrpc\": \"2.0\", \"method\": \"SetParam\", \"params\": {\"machine\":\"anafi4k\", \"object\":\"lipobattery/lipobattery\", \"parameter\":\"discharge_speed_factor\", \"value\":\"0\"}, \"id\": 1}' | curl -d @- http://localhost:8383 | python -m json.tool"
+subprocess.run(command, shell=True)
+assert drone(TakeOff()>> FlyingStateChanged(state="hovering", _timeout=5)).wait().success()
+
 class SaveOnBestTrainingRewardCallback(BaseCallback):
     """
     Callback for saving a model (the check is done every ``check_freq`` steps)
@@ -72,28 +74,26 @@ def _on_step(self) -> bool:
                     self.model.save(self.save_path)
         return True
 
+# Store the episode reward in a csv file
 heading = ["Timestep", "Reward"]
 with open('reward.csv', 'w', newline='') as csvFile:
     writer = csv.writer(csvFile)
     writer.writerow(heading)
 csvFile.close()
 
-# Create log dir
+# Create log directory
 log_dir = "tmp1/"
 os.makedirs(log_dir, exist_ok=True)
-
-
 env = ParrotEnv(destination = [0,0,1], drone= drone)
 #env.reset()
 env = Monitor(env, log_dir)
-#model = DQN(MlpPolicy, env, verbose=1,tensorboard_log="./dqn_lunar_tensorboard/", exploration_fraction = 0.2, learning_rate = 0.001)
-#model = TD3(MlpPolicy, env, verbose=1, learning_rate = 0.0005,tensorboard_log="./td3_parrot_tensorboard/", buffer_size = 25000)
-model = TD3.load("./tmp1/best_model.zip",env)
+model = TD3(MlpPolicy, env, verbose=1, learning_rate = 0.0005, tensorboard_log="./td3_parrot_tensorboard/", buffer_size = 25000)
 callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
-#model = DQN(MlpPolicy, env, verbose=1)
 model.learn(total_timesteps=30000, log_interval=10, tb_log_name = "td3_run_3d", callback = callback)
 model.save("td3_parrot_3d")
+
+# Land the drone and disconnect.
 assert drone(Landing()).wait().success()
 drone.disconnection()

diff --git a/parrotenv.py b/parrotenv.py
index 58f1116..5b0cfbc 100644
--- a/parrotenv.py
+++ b/parrotenv.py
@@ -152,14 +152,12 @@ def step(self, action):
         if z>5:
             z_act = min(0,action[2])
-        self.drone(PCMD(1, y_act, x_act, 0, z_act, timestampAndSeqNum=0, _timeout=10)>> FlyingStateChanged(state="hovering", _timeout=5)).wait()
-
-
+        self.drone(PCMD(1, y_act, x_act, 0, z_act, timestampAndSeqNum=0, _timeout=10)>> FlyingStateChanged(state="hovering", _timeout=5)).wait()
         self.pos_feedback() # Update state of the drone in self.agent_pos
         obs = [self.agent_pos[0]-self.destination[0],self.agent_pos[1]-self.destination[1],self.agent_pos[2]-self.destination[2]]
         d = self.distance([obs[0],obs[1],obs[2]])
-        #Terminating Condition
+        # Terminating condition and reward design
         done = bool(d < 0.5)
         if bool(d < 0.5):
             reward = +100
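
Note: this patch switches training from resuming a checkpoint (TD3.load) to a fresh TD3 model, with SaveOnBestTrainingRewardCallback keeping the best checkpoint under tmp1/. A minimal sketch of how that checkpoint could later be reloaded for evaluation, assuming the Sphinx simulator is still running and `env` is the Monitor-wrapped ParrotEnv from the training script:

    # Reload the best checkpoint written by the callback and fly one
    # evaluation episode (assumes `env` is the Monitor-wrapped ParrotEnv).
    from stable_baselines import TD3

    model = TD3.load("./tmp1/best_model.zip", env=env)
    obs = env.reset()
    done = False
    while not done:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)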
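
Because the environment is wrapped in Monitor(env, log_dir), the learning curve can also be plotted straight from the monitor files. A sketch using Stable Baselines' own results_plotter; the 30000 timesteps and "tmp1/" directory mirror the values used in the training script:

    # Plot episode reward against timesteps from the Monitor logs in tmp1/.
    import matplotlib.pyplot as plt
    from stable_baselines import results_plotter

    results_plotter.plot_results(["tmp1/"], 30000, results_plotter.X_TIMESTEPS, "td3_run_3d")
    plt.show()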