-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathparrot_predict.py
65 lines (58 loc) · 1.81 KB
/
parrot_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Fly a Parrot drone through fixed waypoints using a trained TD3 policy.

The script:

1. Loads a TD3 model (Stable Baselines 2.10) from ``./tmp/best_model.zip``.
2. Connects to the drone at ``10.202.0.1`` and takes off, waiting for the
   ``hovering`` state.
3. Visits waypoints A, B, C and D in order. For each one it builds a
   ``ParrotEnv`` and feeds the policy's actions to ``env.step`` until the
   environment reports ``done``.
4. Lands and disconnects. This step runs even if an earlier step raised, so
   the drone is not left hovering.

Author: Gargi Vaidya & Vishnu Saj
"""
import olympe
from parrotenv import ParrotEnv
from olympe.messages.ardrone3.Piloting import TakeOff, moveBy, Landing, moveTo
from olympe.messages.ardrone3.PilotingState import FlyingStateChanged
import os
import gym
import numpy as np
import matplotlib.pyplot as plt
from stable_baselines import TD3
from stable_baselines.td3.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv

DRONE_IP = "10.202.0.1"
MODEL_PATH = "./tmp/best_model.zip"

# Waypoint targets, passed to ParrotEnv as ``destination``.
# NOTE(review): presumably [x, y, z] in the environment's frame; confirm
# against ParrotEnv.
A = [3, -3, 3]
B = [3, 3, 5]
C = [-3, 3, 2]
D = [-3, -3, 3]
WAYPOINTS = (A, B, C, D)


def fly_to_waypoint(drone, model, destination, obs):
    """Run the policy in a fresh ParrotEnv until the episode reports done.

    Args:
        drone: connected ``olympe.Drone`` handed to the environment.
        model: policy exposing ``predict(obs) -> (action, state)``.
        destination: waypoint passed to ``ParrotEnv`` as its target.
        obs: observation fed to the policy on the first step.

    Returns:
        The last observation returned by ``env.step``, so the next leg can
        start from it.
    """
    env = ParrotEnv(destination=destination, drone=drone)
    # NOTE(review): env.reset() is never called. The first observation is the
    # one carried over from the previous leg (or the initial [0, 0, 0]).
    done = False
    while not done:
        action, _states = model.predict(obs)
        obs, _rewards, done, _info = env.step(action)
        env.render()
    return obs


# Load the trained policy before take-off, so that a missing or corrupt
# checkpoint fails while the drone is still on the ground.
model = TD3.load(MODEL_PATH)

drone = olympe.Drone(DRONE_IP)
drone.connection()
try:
    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    took_off = drone(
        TakeOff() >> FlyingStateChanged(state="hovering", _timeout=5)
    ).wait().success()
    if not took_off:
        raise RuntimeError("Take-off failed: drone did not reach 'hovering'")

    # Visit each waypoint in turn; each leg starts from the observation at
    # which the previous leg finished.
    obs = [0, 0, 0]
    for waypoint in WAYPOINTS:
        obs = fly_to_waypoint(drone, model, waypoint, obs)
finally:
    # Always attempt to land and disconnect, even if a leg raised.
    try:
        landed = drone(Landing()).wait().success()
    finally:
        drone.disconnection()

# Only reached when no exception escaped the flight sequence above.
if not landed:
    raise RuntimeError("Landing command did not complete successfully")