
Commit

Update callback function to save model
Gargi Vaidya committed Feb 13, 2021
1 parent aaaca69 commit 28eb3dc
Showing 2 changed files with 19 additions and 21 deletions.
34 changes: 17 additions & 17 deletions parrot_training.py
@@ -1,26 +1,21 @@
# Continuous Action Parrot Drone Environment 3D Space Training Script
"""
Benchmark reinforcement learning (RL) algorithms from Stable Baselines 2.10.
Author: Gargi Vaidya & Vishnu Saj
- Note: Modify the RL algorithm from StableBaselines and tune the hyperparameters.
"""
import olympe
from matplotlib import pyplot as plt
from parrotenv import ParrotEnv
import subprocess
from olympe.messages.ardrone3.Piloting import TakeOff, moveBy, Landing,moveTo,PCMD
from olympe.messages.ardrone3.PilotingState import FlyingStateChanged, AttitudeChanged, moveByChanged, AltitudeChanged, PositionChanged

drone = olympe.Drone("10.202.0.1")
drone.connection()
command = "echo '{\"jsonrpc\": \"2.0\", \"method\": \"SetParam\", \"params\": {\"machine\":\"anafi4k\", \"object\":\"lipobattery/lipobattery\", \"parameter\":\"discharge_speed_factor\", \"value\":\"0\"}, \"id\": 1}' | curl -d @- http://localhost:8383 | python -m json.tool"
subprocess.run(command, shell=True)
assert drone(TakeOff()>> FlyingStateChanged(state="hovering", _timeout=5)).wait().success()
#drone.start_piloting()

import os
import csv
import gym
import csv
import numpy as np
import matplotlib.pyplot as plt

from stable_baselines import TD3
from stable_baselines.td3.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
@@ -29,6 +24,13 @@
from stable_baselines.results_plotter import load_results, ts2xy
from stable_baselines.common.callbacks import BaseCallback

# Define the drone model, connect, disable simulated battery drain, and take off
drone = olympe.Drone("10.202.0.1")
drone.connection()
command = "echo '{\"jsonrpc\": \"2.0\", \"method\": \"SetParam\", \"params\": {\"machine\":\"anafi4k\", \"object\":\"lipobattery/lipobattery\", \"parameter\":\"discharge_speed_factor\", \"value\":\"0\"}, \"id\": 1}' | curl -d @- http://localhost:8383 | python -m json.tool"
subprocess.run(command, shell=True)
assert drone(TakeOff()>> FlyingStateChanged(state="hovering", _timeout=5)).wait().success()

class SaveOnBestTrainingRewardCallback(BaseCallback):
"""
Callback for saving a model (the check is done every ``check_freq`` steps)
@@ -72,28 +74,26 @@ def _on_step(self) -> bool:
self.model.save(self.save_path)

return True
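
# Note: the diff above collapses most of the callback body. For orientation, the class
# follows the SaveOnBestTrainingRewardCallback pattern from the Stable Baselines 2
# documentation; the sketch below shows that pattern and is an assumption for context,
# not the repository's exact code (verbosity and logging details may differ).

class SaveOnBestTrainingRewardCallback(BaseCallback):
    def __init__(self, check_freq, log_dir, verbose=1):
        super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.log_dir = log_dir
        self.save_path = os.path.join(log_dir, 'best_model')
        self.best_mean_reward = -np.inf

    def _init_callback(self):
        # Create the save folder if needed.
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            # Mean reward over the last 100 episodes, read from the Monitor logs.
            x, y = ts2xy(load_results(self.log_dir), 'timesteps')
            if len(x) > 0:
                mean_reward = np.mean(y[-100:])
                if mean_reward > self.best_mean_reward:
                    self.best_mean_reward = mean_reward
                    self.model.save(self.save_path)
        return True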
# Stores a csv file for the episode reward
heading = ["Timestep", "Reward"]
with open('reward.csv', 'w', newline='') as csvFile:
writer = csv.writer(csvFile)
writer.writerow(heading)
csvFile.close()
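
# Only the header row is written here; reward rows are presumably appended elsewhere
# during training. A hypothetical append helper (name assumed, not part of this commit):

def log_reward(timestep, reward, path='reward.csv'):
    # Append one "Timestep, Reward" row to the log created above.
    with open(path, 'a', newline='') as f:
        csv.writer(f).writerow([timestep, reward])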

# Create log dir
# Create log directory
log_dir = "tmp1/"
os.makedirs(log_dir, exist_ok=True)



env = ParrotEnv(destination = [0,0,1], drone= drone)
#env.reset()
env = Monitor(env, log_dir)

#model = DQN(MlpPolicy, env, verbose=1,tensorboard_log="./dqn_lunar_tensorboard/", exploration_fraction = 0.2, learning_rate = 0.001)
#model = TD3(MlpPolicy, env, verbose=1, learning_rate = 0.0005,tensorboard_log="./td3_parrot_tensorboard/", buffer_size = 25000)
model = TD3.load("./tmp1/best_model.zip",env)
model = TD3(MlpPolicy, env, verbose=1, learning_rate = 0.0005,tensorboard_log="./td3_parrot_tensorboard/", buffer_size = 25000)
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
#model = DQN(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=30000, log_interval=10,tb_log_name = "td3_run_3d", callback = callback)
model.save("td3_parrot_3d")

# Land the drone and disconnect.
assert drone(Landing()).wait().success()
drone.disconnection()
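
# Not part of this commit: a minimal sketch of how the saved policy could later be
# reloaded and run for evaluation, assuming the same ParrotEnv/Monitor setup and
# Stable Baselines 2.10 as above.

model = TD3.load("td3_parrot_3d")
obs = env.reset()
for _ in range(200):
    # Deterministic actions for evaluation of the trained TD3 policy.
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
    if done:
        obs = env.reset()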
6 changes: 2 additions & 4 deletions parrotenv.py
@@ -152,14 +152,12 @@ def step(self, action):
if z>5:
z_act = min(0,action[2])

self.drone(PCMD(1, y_act, x_act, 0, z_act, timestampAndSeqNum=0, _timeout=10)>> FlyingStateChanged(state="hovering", _timeout=5)).wait()


self.drone(PCMD(1, y_act, x_act, 0, z_act, timestampAndSeqNum=0, _timeout=10)>> FlyingStateChanged(state="hovering", _timeout=5)).wait()
self.pos_feedback() # Update state of the drone in self.agent_pos
obs = [self.agent_pos[0]-self.destination[0],self.agent_pos[1]-self.destination[1],self.agent_pos[2]-self.destination[2]]
d = self.distance([obs[0],obs[1],obs[2]])

#Terminating Condition
#Terminating Condition and reward design
done = bool(d < 0.5)
if bool(d < 0.5):
reward = +100
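
# The hunk is truncated below the success branch. Purely as an illustration of a
# distance-based design (hypothetical, not this repository's code), the terminal
# bonus is typically paired with a per-step distance penalty:

def compute_reward(d, threshold=0.5):
    # Hypothetical illustration only: +100 bonus on reaching the target,
    # otherwise penalize the remaining distance to the destination.
    done = d < threshold
    reward = 100.0 if done else -d
    return reward, done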
