Skip to content

Commit

Permalink
Control ros node with cmd line arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
ranzuh committed Sep 5, 2023
1 parent d29cb47 commit d1725f5
Showing 1 changed file with 84 additions and 8 deletions.
92 changes: 84 additions & 8 deletions ros_nodes/franka_nn_node/mobilefranka_rl_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,18 @@
import numpy as np
#import moveit_commander
import sys
from collections import defaultdict


class MobileFrankaRLNode:

def __init__(self):
def __init__(self, argv):

experiment = argv[1]
point = argv[2]
method = argv[3]
run_number = argv[4]
self.output_filename = "{}_{}_{}_{}.csv".format(experiment, point, method, run_number)

self.arm_joint_sub = rospy.Subscriber('/franka_state_controller/joint_states', JointState, self.arm_callback)
#self.gripper_joint_sub = rospy.Subscriber('/franka_gripper/joint_states', JointState, self.gripper_callback)
Expand All @@ -29,13 +36,36 @@ def __init__(self):
self.target_pub = rospy.Publisher("/target", PointStamped, queue_size=20)

#self.ort_model = ort.InferenceSession("models/single_agent_mobilefranka.onnx")
self.ort_model = ort.InferenceSession("models/mobilefrankaMARL_cv.onnx")
#self.ort_model = ort.InferenceSession("models/mobilefrankaMARL_cv.onnx")
#self.ort_model = ort.InferenceSession("models/m1_exp_reward_mobilefranka.onnx")

#self.ort_model = ort.InferenceSession("mobilefranka_no_base_vel_yaw_fix_easier_target.onnx")

self.single_agent = False
self.use_cv = True
if experiment == "e1" or experiment == "e4":
self.arm_control = True
else:
self.arm_control = False

if experiment == "e4":
self.base_control = False
else:
self.base_control = True

if method == "baseline":
#self.ort_model = ort.InferenceSession("models/single_agent_mobilefranka.onnx")
self.ort_model = ort.InferenceSession("models/newmodels/baselinebest.onnx")
self.single_agent = True
self.use_cv = False
elif method == "m1":
self.ort_model = ort.InferenceSession("models/m1_exp_reward_mobilefranka.onnx")
#self.ort_model = ort.InferenceSession("models/newmodels/m1best.onnx")
self.single_agent = False
self.use_cv = False
elif method == "m2":
#self.ort_model = ort.InferenceSession("models/mobilefrankaMARL_cv.onnx")
self.ort_model = ort.InferenceSession("models/newmodels/m2best.onnx")
self.single_agent = False
self.use_cv = True

self.joint_positions = np.zeros(9)
self.joint_velocities = np.zeros(9)
Expand All @@ -46,7 +76,25 @@ def __init__(self):

default_joint_pos = [0.0, -0.7856, 0.0, -2.356, 0.0, 1.572, 0.7854, 0.035, 0.035]

self.target_pos = np.array([0.4, -0.6, 0.5])
# p1 0.4, -0.6, 0.5
# p2 0.3, 2.0, 0.7
# p3 -2.0, -1.0, 0.4
# set target position dynamically based on argument "point"
if point == "p1":
self.target_pos = np.array([0.4, -0.6, 0.5])
elif point == "p2":
self.target_pos = np.array([0.3, 2.0, 0.7])
elif point == "p3":
self.target_pos = np.array([-2.0, -1.0, 0.4])
else:
print("invalid point")
return

self.start_time = None
self.data = defaultdict(list)
rospy.on_shutdown(self.shutdown_hook)

#self.target_pos = np.array([-2.0, -1.0, 0.5])
self.joint_targets = None

self.base_position = None
Expand All @@ -71,6 +119,12 @@ def __init__(self):

#rospy.Timer(rospy.Duration(1/10.0), self.update_base_velocity)

def shutdown_hook(self):
import pandas as pd
print("\nExporting data to csv file: ", self.output_filename)
df = pd.DataFrame(self.data)
df.to_csv(self.output_filename, index=False)

def publish_target(self):
target = PointStamped()
target.header.frame_id = "universe"
Expand Down Expand Up @@ -115,6 +169,7 @@ def update_base_pose(self, timer_event):
try:
optitrack_trans = self.tfBuffer.lookup_transform('universe', 'husky_link', rospy.get_rostime(), rospy.Duration(1.0))
self.first_position = np.array([optitrack_trans.transform.translation.x, optitrack_trans.transform.translation.y, optitrack_trans.transform.translation.z])
self.start_time = rospy.Time.now()
except (tf2_ros.LookupException, tf2_ros.ConnectivityException, tf2_ros.ExtrapolationException):
print("tf error")
return
Expand All @@ -133,6 +188,9 @@ def update_base_pose(self, timer_event):
#print("self.base_position ", self.base_position)
#print("self.left_finger_position", self.left_finger_position)
distance_from_target = np.linalg.norm([self.target_pos - self.left_finger_position])
current_time = rospy.Time.now()
elapsed_time = current_time - self.start_time

print("base x:", '{:.3f}'.format(self.base_position[0]))
print("base y:", '{:.3f}'.format(self.base_position[1]))
print("base yaw:", '{:.3f}'.format(self.base_yaw))
Expand All @@ -143,11 +201,27 @@ def update_base_pose(self, timer_event):
print("target y:", '{:.3f}'.format(self.target_pos[1]))
print("target z:", '{:.3f}'.format(self.target_pos[2]))
print("distance from target:", distance_from_target)
print("elapsed time:", elapsed_time.to_sec())
print("-----------------")
# if distance_from_target < 0.08:
# print("Goal reached!", "distance:", distance_from_target)
# rospy.signal_shutdown()
self.publish_target()

current_time = rospy.Time.now()
elapsed_time = current_time - self.start_time

self.data["base_x"].append(self.base_position[0])
self.data["base_y"].append(self.base_position[1])
self.data["base_yaw"].append(self.base_yaw)
self.data["left_finger_x"].append(self.left_finger_position[0])
self.data["left_finger_y"].append(self.left_finger_position[1])
self.data["left_finger_z"].append(self.left_finger_position[2])
self.data["target_x"].append(self.target_pos[0])
self.data["target_y"].append(self.target_pos[1])
self.data["target_z"].append(self.target_pos[2])
self.data["distance_from_target"].append(distance_from_target)
self.data["elapsed_time"].append(elapsed_time.to_sec())

except (tf2_ros.LookupException, tf2_ros.ConnectivityException, tf2_ros.ExtrapolationException):
print("tf error")
Expand Down Expand Up @@ -344,7 +418,8 @@ def send_control(self, timer_event):
#print("joint_goal", joint_goal)
#print("scaled_action", dof_speed_scales * self.dt * arm_action * 7.5)

#self.trajectory_goal_pub.publish(goal)
if self.arm_control:
self.trajectory_goal_pub.publish(goal)

#self.base_vel_queue.append(base_action)

Expand All @@ -360,7 +435,8 @@ def send_control(self, timer_event):
#print("base cmd", twist.linear.x, twist.angular.z)

#print("base_action:", base_action[:2])
self.base_cmd_vel_pub.publish(twist)
if self.base_control:
self.base_cmd_vel_pub.publish(twist)

# TODO Need to publish the base action as moving average of the last 10 actions
# self.base_action_buffer.append(base_action)
Expand All @@ -380,5 +456,5 @@ def send_control(self, timer_event):

if __name__ == '__main__':
rospy.init_node('rl_node', anonymous=True)
MobileFrankaRLNode()
MobileFrankaRLNode(sys.argv)
rospy.spin()

0 comments on commit d1725f5

Please sign in to comment.