forked from facebookresearch/pyrobot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
locobot_kobuki.py
196 lines (159 loc) · 5.28 KB
/
locobot_kobuki.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Locomotion without crashing example with the PyRobot API.
Associated launch command:
roslaunch locobot_control main.launch use_base:=true use_arm:=true base:=kobuki use_camera:=true
Follow the associated README for installation instructions.
"""
import sys
import crash_utils.model as model
from crash_utils.test import Tester

# Register crash_utils.model under the top-level module name "model".
# NOTE(review): presumably this lets the unpickler inside torch.load (called in
# main()) resolve classes the checkpoint recorded as living in a module named
# "model" -- confirm against how the checkpoint was saved. It executes at import
# time, i.e. before main() loads the checkpoint.
sys.modules["model"] = model
from pyrobot import Robot
import os
import errno
import time
import torch
import argparse
# --- Motion commands ---------------------------------------------------------
# Linear term handed to base.set_vel when driving straight.
STRAIGHT_VELOCITY = 0.2
# Magnitude of the angular term handed to base.set_vel when turning.
TURNING_VELOCITY = 1.0
# Duration argument handed to every base.set_vel call.
TIME_STEP = 0.5
# The model's evals[3] must exceed this value for the robot to drive straight.
STRAIGHT_THRESHOLD = 0.5

# --- Camera ------------------------------------------------------------------
# (pan, tilt) unpacked into camera.set_pan_tilt before the control loop starts.
DEFAULT_PAN_TILT = [0, -0.0]

# --- Crash model checkpoint --------------------------------------------------
# Default download location of the checkpoint (overridable with --model_url).
MODEL_URL = "https://www.dropbox.com/s/dr4npkwz9j6c9rk/checkpoint.pth.bst?dl=0"
# Default directory the checkpoint is cached in (overridable with --save_dir).
SAVE_DIR = "./models"
def download_if_not_present(model_path, url):
    """
    Download a file from a URL to a local path, unless a file is already there.

    The parent directory of ``model_path`` is created if it does not exist.
    The file is fetched with the external ``wget`` program into a temporary
    ``<model_path>.part`` file. That file is renamed to ``model_path`` only
    after ``wget`` exits successfully. As a result, a failed or interrupted
    download never leaves behind a truncated file that a later run would
    mistake for a complete one.

    :param model_path: Path where the file should be downloaded to
    :param url: URL from where the file will be downloaded from
    :type model_path: string
    :type url: string
    :raises RuntimeError: if ``wget`` exits with a non-zero status
    :raises OSError: if the parent directory cannot be created or ``wget``
        cannot be executed
    """
    import subprocess

    if os.path.isfile(model_path):
        return

    model_dir = os.path.dirname(model_path)
    # A bare filename has an empty dirname: os.makedirs("") would raise, and
    # the current directory needs no creating anyway.
    if model_dir and not os.path.exists(model_dir):
        try:
            os.makedirs(model_dir)
        except OSError as exc:  # Guard against race condition
            if exc.errno != errno.EEXIST:
                raise

    print("CRASH MODEL NOT FOUND! DOWNLOADING IT!")
    partial_path = model_path + ".part"
    # Pass an argument list (no shell) so spaces or shell metacharacters in
    # the URL or path are neither word-split nor interpreted.
    status = subprocess.call(["wget", url, "-O", partial_path])
    if status != 0:
        # wget -O may already have created an empty/partial output file.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise RuntimeError(
            "wget exited with status {} while downloading {}".format(status, url)
        )
    os.rename(partial_path, model_path)
def go_straight(bot, vel=STRAIGHT_VELOCITY, t=TIME_STEP):
    """
    Command the base to move forward with no rotation.

    :param bot: A pyrobot.Robot object
    :param vel: Linear velocity passed to ``bot.base.set_vel``
    :param t: Duration passed to ``bot.base.set_vel``
    :type bot: pyrobot.Robot
    :type vel: float
    :type t: float
    """
    print("Straight!!")
    forward, rotation = vel, 0.0
    bot.base.set_vel(forward, rotation, t)
def turn_left(bot, vel=TURNING_VELOCITY, t=TIME_STEP):
    """
    Command the base to rotate in place to the left.

    :param bot: A pyrobot.Robot object
    :param vel: Angular velocity passed (unchanged) to ``bot.base.set_vel``
    :param t: Duration passed to ``bot.base.set_vel``
    :type bot: pyrobot.Robot
    :type vel: float
    :type t: float
    """
    print("Left!!")
    forward = 0.0
    rotation = vel  # left turns use the angular velocity as given
    bot.base.set_vel(forward, rotation, t)
def turn_right(bot, vel=TURNING_VELOCITY, t=TIME_STEP):
    """
    Command the base to rotate in place to the right.

    :param bot: A pyrobot.Robot object
    :param vel: Turning speed; its negation is passed to ``bot.base.set_vel``
    :param t: Duration passed to ``bot.base.set_vel``
    :type bot: pyrobot.Robot
    :type vel: float
    :type t: float
    """
    print("Right!!")
    rotation = -vel  # right turns flip the sign of the angular velocity
    bot.base.set_vel(0.0, rotation, t)
def _next_direction(evals, hist):
    """
    Choose the direction ("straight", "left" or "right") for one control step.

    The robot drives straight whenever ``evals[3]`` exceeds STRAIGHT_THRESHOLD.
    Otherwise it turns. Coming out of "straight", it turns left if
    ``evals[1] > evals[2]`` and right otherwise. If it is already turning, it
    keeps the previous turn direction, presumably to avoid oscillating between
    left and right.

    :param evals: Output of ``Tester.test``; only indices 1, 2 and 3 are read.
        NOTE(review): the meaning of each index is defined by crash_utils.
    :param hist: Direction chosen on the previous step.
    :returns: The direction for this step.
    """
    if evals[3] > STRAIGHT_THRESHOLD:
        return "straight"
    if hist == "straight":
        return "left" if evals[1] > evals[2] else "right"
    return hist


def main(args):
    """
    Run the locomotion-without-crashing demo.

    The demo proceeds as follows:

    1. Reset and point the camera.
    2. Load the crash model, downloading it first if it is not cached.
    3. Repeatedly grab an RGB frame, score it with the model, and issue one
       straight/left/right motion command.

    The loop stops after ``args.n_loops`` iterations or once ``args.n_secs``
    seconds have elapsed, whichever comes first.

    :param args: Parsed command-line arguments; reads ``display_images``,
        ``save_dir``, ``model_url``, ``n_loops`` and ``n_secs``.
    :type args: argparse.Namespace
    """
    if args.display_images:
        from pyrobot.utils.util import try_cv2_import

        cv2 = try_cv2_import()

    bot = Robot("locobot", base_config={"base_planner": "none"})
    bot.camera.reset()
    print("Setting pan: {}, tilt: {}".format(*DEFAULT_PAN_TILT))
    bot.camera.set_pan_tilt(*DEFAULT_PAN_TILT, wait=True)

    model_path = os.path.join(args.save_dir, "crash_model.pth")
    download_if_not_present(model_path, args.model_url)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only point --model_url at a checkpoint you trust.
    crash_model = torch.load(model_path)
    evaluator = Tester(crash_model)

    moves = {"straight": go_straight, "left": turn_left, "right": turn_right}
    control_start = time.time()
    hist = "straight"
    for _ in range(args.n_loops):
        rgb, _ = bot.camera.get_rgb_depth()
        evals = evaluator.test(rgb)
        print(evals)
        hist = _next_direction(evals, hist)
        moves[hist](bot)
        if args.display_images:
            cv2.imshow("image", evaluator.image)
            cv2.waitKey(10)
        if time.time() - control_start >= args.n_secs:
            print("Time limit exceeded")
            break
if __name__ == "__main__":
    # Command-line options as (flags, add_argument keyword arguments), listed
    # in the order they are registered (and therefore shown in --help).
    cli_options = (
        (
            ("-n", "--n_secs"),
            dict(
                help="Number of seconds to run the avoiding crashing controller",
                type=int,
                default=60,
            ),
        ),
        (
            ("-l", "--n_loops"),
            dict(
                help="Number of loops to run the avoiding crashing controller",
                type=int,
                default=1000,
            ),
        ),
        (
            ("-u", "--model_url"),
            dict(help="URL to download model from", default=MODEL_URL),
        ),
        (
            ("-s", "--save_dir"),
            dict(help="Directory to save model", default=SAVE_DIR),
        ),
        (
            ("-d", "--visualize"),
            dict(
                help="True to visualize images at each control loop, False otherwise",
                dest="display_images",
                action="store_true",
            ),
        ),
    )

    cli = argparse.ArgumentParser(
        description="Process args for moving without crashing"
    )
    for flags, options in cli_options:
        cli.add_argument(*flags, **options)
    cli.set_defaults(display_images=False)
    main(cli.parse_args())