diff --git a/ros_nodes/franka_nn_node/mobilefranka_rl_node.py b/ros_nodes/franka_nn_node/mobilefranka_rl_node.py index f2946305..e3714674 100644 --- a/ros_nodes/franka_nn_node/mobilefranka_rl_node.py +++ b/ros_nodes/franka_nn_node/mobilefranka_rl_node.py @@ -13,11 +13,18 @@ import numpy as np #import moveit_commander import sys +from collections import defaultdict class MobileFrankaRLNode: - def __init__(self): + def __init__(self, argv): + + experiment = argv[1] + point = argv[2] + method = argv[3] + run_number = argv[4] + self.output_filename = "{}_{}_{}_{}.csv".format(experiment, point, method, run_number) self.arm_joint_sub = rospy.Subscriber('/franka_state_controller/joint_states', JointState, self.arm_callback) #self.gripper_joint_sub = rospy.Subscriber('/franka_gripper/joint_states', JointState, self.gripper_callback) @@ -29,13 +36,36 @@ def __init__(self): self.target_pub = rospy.Publisher("/target", PointStamped, queue_size=20) #self.ort_model = ort.InferenceSession("models/single_agent_mobilefranka.onnx") - self.ort_model = ort.InferenceSession("models/mobilefrankaMARL_cv.onnx") + #self.ort_model = ort.InferenceSession("models/mobilefrankaMARL_cv.onnx") #self.ort_model = ort.InferenceSession("models/m1_exp_reward_mobilefranka.onnx") #self.ort_model = ort.InferenceSession("mobilefranka_no_base_vel_yaw_fix_easier_target.onnx") - self.single_agent = False - self.use_cv = True + if experiment == "e1" or experiment == "e4": + self.arm_control = True + else: + self.arm_control = False + + if experiment == "e4": + self.base_control = False + else: + self.base_control = True + + if method == "baseline": + #self.ort_model = ort.InferenceSession("models/single_agent_mobilefranka.onnx") + self.ort_model = ort.InferenceSession("models/newmodels/baselinebest.onnx") + self.single_agent = True + self.use_cv = False + elif method == "m1": + self.ort_model = ort.InferenceSession("models/m1_exp_reward_mobilefranka.onnx") + #self.ort_model = ort.InferenceSession("models/newmodels/m1best.onnx") + self.single_agent = False + self.use_cv = False + elif method == "m2": + #self.ort_model = ort.InferenceSession("models/mobilefrankaMARL_cv.onnx") + self.ort_model = ort.InferenceSession("models/newmodels/m2best.onnx") + self.single_agent = False + self.use_cv = True self.joint_positions = np.zeros(9) self.joint_velocities = np.zeros(9) @@ -46,7 +76,25 @@ def __init__(self): default_joint_pos = [0.0, -0.7856, 0.0, -2.356, 0.0, 1.572, 0.7854, 0.035, 0.035] - self.target_pos = np.array([0.4, -0.6, 0.5]) + # p1 0.4, -0.6, 0.5 + # p2 0.3, 2.0, 0.7 + # p3 -2.0, -1.0, 0.4 + # set target position dynamically based on argument "point" + if point == "p1": + self.target_pos = np.array([0.4, -0.6, 0.5]) + elif point == "p2": + self.target_pos = np.array([0.3, 2.0, 0.7]) + elif point == "p3": + self.target_pos = np.array([-2.0, -1.0, 0.4]) + else: + print("invalid point") + return + + self.start_time = None + self.data = defaultdict(list) + rospy.on_shutdown(self.shutdown_hook) + + #self.target_pos = np.array([-2.0, -1.0, 0.5]) self.joint_targets = None self.base_position = None @@ -71,6 +119,12 @@ def __init__(self): #rospy.Timer(rospy.Duration(1/10.0), self.update_base_velocity) + def shutdown_hook(self): + import pandas as pd + print("\nExporting data to csv file: ", self.output_filename) + df = pd.DataFrame(self.data) + df.to_csv(self.output_filename, index=False) + def publish_target(self): target = PointStamped() target.header.frame_id = "universe" @@ -115,6 +169,7 @@ def update_base_pose(self, timer_event): try: optitrack_trans = self.tfBuffer.lookup_transform('universe', 'husky_link', rospy.get_rostime(), rospy.Duration(1.0)) self.first_position = np.array([optitrack_trans.transform.translation.x, optitrack_trans.transform.translation.y, optitrack_trans.transform.translation.z]) + self.start_time = rospy.Time.now() except (tf2_ros.LookupException, tf2_ros.ConnectivityException, tf2_ros.ExtrapolationException): print("tf error") return @@ -133,6 +188,9 @@ def update_base_pose(self, timer_event): #print("self.base_position ", self.base_position) #print("self.left_finger_position", self.left_finger_position) distance_from_target = np.linalg.norm([self.target_pos - self.left_finger_position]) + current_time = rospy.Time.now() + elapsed_time = current_time - self.start_time + print("base x:", '{:.3f}'.format(self.base_position[0])) print("base y:", '{:.3f}'.format(self.base_position[1])) print("base yaw:", '{:.3f}'.format(self.base_yaw)) @@ -143,11 +201,27 @@ def update_base_pose(self, timer_event): print("target y:", '{:.3f}'.format(self.target_pos[1])) print("target z:", '{:.3f}'.format(self.target_pos[2])) print("distance from target:", distance_from_target) + print("elapsed time:", elapsed_time.to_sec()) print("-----------------") # if distance_from_target < 0.08: # print("Goal reached!", "distance:", distance_from_target) # rospy.signal_shutdown() self.publish_target() + + current_time = rospy.Time.now() + elapsed_time = current_time - self.start_time + + self.data["base_x"].append(self.base_position[0]) + self.data["base_y"].append(self.base_position[1]) + self.data["base_yaw"].append(self.base_yaw) + self.data["left_finger_x"].append(self.left_finger_position[0]) + self.data["left_finger_y"].append(self.left_finger_position[1]) + self.data["left_finger_z"].append(self.left_finger_position[2]) + self.data["target_x"].append(self.target_pos[0]) + self.data["target_y"].append(self.target_pos[1]) + self.data["target_z"].append(self.target_pos[2]) + self.data["distance_from_target"].append(distance_from_target) + self.data["elapsed_time"].append(elapsed_time.to_sec()) except (tf2_ros.LookupException, tf2_ros.ConnectivityException, tf2_ros.ExtrapolationException): print("tf error") @@ -344,7 +418,8 @@ def send_control(self, timer_event): #print("joint_goal", joint_goal) #print("scaled_action", dof_speed_scales * self.dt * arm_action * 7.5) - #self.trajectory_goal_pub.publish(goal) + if self.arm_control: + self.trajectory_goal_pub.publish(goal) #self.base_vel_queue.append(base_action) @@ -360,7 +435,8 @@ def send_control(self, timer_event): #print("base cmd", twist.linear.x, twist.angular.z) #print("base_action:", base_action[:2]) - self.base_cmd_vel_pub.publish(twist) + if self.base_control: + self.base_cmd_vel_pub.publish(twist) # TODO Need to publish the base action as moving average of the last 10 actions # self.base_action_buffer.append(base_action) @@ -380,5 +456,5 @@ def send_control(self, timer_event): if __name__ == '__main__': rospy.init_node('rl_node', anonymous=True) - MobileFrankaRLNode() + MobileFrankaRLNode(sys.argv) rospy.spin() \ No newline at end of file