Skip to content

Commit

Permalink
Fix rep errors in Mini-cheetah data
Browse files Browse the repository at this point in the history
  • Loading branch information
Danfoa committed May 15, 2024
1 parent c93f10e commit 52f2c51
Show file tree
Hide file tree
Showing 2 changed files with 117 additions and 116 deletions.
69 changes: 45 additions & 24 deletions morpho_symm/data/DynamicsRecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def state_moments(self) -> [np.ndarray, np.ndarray]:

def compute_obs_moments(self, obs_name: str) -> [np.ndarray, np.ndarray]:
"""Compute the mean and standard deviation of observations."""
assert obs_name in self.recordings.keys(), f"Observation {obs_name} not found in recordings"
assert obs_name in self.recordings.keys(), f"Observation {obs_name} not found in recording"
is_symmetric_obs = obs_name in self.obs_representations.keys()
if is_symmetric_obs:
rep_obs = self.obs_representations[obs_name]
Expand Down Expand Up @@ -113,17 +113,17 @@ def compute_obs_moments(self, obs_name: str) -> [np.ndarray, np.ndarray]:

# TODO: Move this check to Unit test as it is computationally demanding to check this at runtime.
# Ensure the mean is equivalent to computing the mean of the orbit of the recording under the group action
aug_obs = []
for g in G.elements:
g_obs = np.einsum('...ij,...j->...i', rep_obs(g), obs_original_basis)
aug_obs.append(g_obs)

aug_obs = np.concatenate(aug_obs, axis=0) # Append over the trajectory dimension
mean_emp = np.mean(aug_obs, axis=(0, 1))
assert np.allclose(mean, mean_emp, rtol=1e-3, atol=1e-3), f"Mean {mean} != {mean_emp}"

var_emp = np.var(aug_obs, axis=(0, 1))
assert np.allclose(var, var_emp, rtol=1e-2, atol=1e-2), f"Var {var} != {var_emp}"
# aug_obs = []
# for g in G.elements:
# g_obs = np.einsum('...ij,...j->...i', rep_obs(g), obs_original_basis)
# aug_obs.append(g_obs)
#
# aug_obs = np.concatenate(aug_obs, axis=0) # Append over the trajectory dimension
# mean_emp = np.mean(aug_obs, axis=(0, 1))
# assert np.allclose(mean, mean_emp, rtol=1e-3, atol=1e-3), f"Mean {mean} != {mean_emp}"
#
# var_emp = np.var(aug_obs, axis=(0, 1))
# assert np.allclose(var, var_emp, rtol=1e-2, atol=1e-2), f"Var {var} != {var_emp}"
else:
mean = np.mean(np.asarray(self.recordings[obs_name]), axis=(0, 1))
var = np.var(np.asarray(self.recordings[obs_name]), axis=(0, 1))
Expand Down Expand Up @@ -151,6 +151,13 @@ def get_state_trajs(self, standardize: bool = False):

return state_trajs

def get_state_dim_names(self):
dim_names = []
for obs_name in self.state_obs:
obs_dim = self.obs_dims[obs_name]
dim_names += [f"{obs_name}:{i}" for i in range(obs_dim)]
return dim_names

def save_to_file(self, file_path: Path):
# Store representations and groups without serializing
if len(self.obs_representations) > 0:
Expand All @@ -169,6 +176,7 @@ def save_to_file(self, file_path: Path):
self.dynamics_parameters.pop('group', None)

with file_path.with_suffix(".pkl").open('wb') as file:
self._path = file_path.with_suffix(".pkl").absolute()
pickle.dump(self, file, protocol=pickle.HIGHEST_PROTOCOL)

@staticmethod
Expand Down Expand Up @@ -419,16 +427,11 @@ def split_train_val_test(
from morpho_symm.utils.mysc import TemporaryNumpySeed
with TemporaryNumpySeed(10): # Ensure deterministic behavior
# Decide to keep a ratio of the original trajectories
num_trajs = int(dyn_recording.info['num_traj'])
state_traj = dyn_recording.get_state_trajs()
assert state_traj.ndim == 3, f"Expectec (traj, time, state_dim) but got {state_traj.shape}"
num_trajs, time_horizon, state_dim = state_traj.shape
split_time = time_horizon > num_trajs
if split_time: # Do not discard entire trajectories, but rather parts of the trajectories
# Take the time horizon from the first observation
sample_obs = dyn_recording.recordings[dyn_recording.state_obs[0]]
if len(sample_obs.shape) == 3: # [traj, time, obs_dim]
time_horizon = sample_obs.shape[1]
elif len(sample_obs.shape) == 2: # [traj, obs_dim]
time_horizon = sample_obs.shape[0]
else:
raise RuntimeError(f"Invalid shape {sample_obs.shape} of {dyn_recording.state_obs[0]}")

num_samples = time_horizon
min_idx = 0
Expand All @@ -453,9 +456,27 @@ def split_train_val_test(
raise RuntimeError(f"Invalid shape {dyn_recording.recordings[obs_name].shape} of {obs_name}")
partitions_recordings[partition_name].recordings[obs_name] = data

return partitions_recordings['train'], partitions_recordings['val'], partitions_recordings['test']
else: # Discard entire trajectories
raise NotImplementedError()
else: # Select train/val/test from individual trajectories
num_samples = num_trajs
min_idx = 0
partitions_sample_idx = {partition: None for partition in partitions_names}
for partition_name, ratio in zip(partitions_names, partition_sizes):
max_idx = min_idx + int(num_samples * ratio)
partitions_sample_idx[partition_name] = list(range(min_idx, max_idx))
min_idx = min_idx + int(num_samples * ratio)

partitions_recordings = {partition: copy.deepcopy(dyn_recording) for partition in partitions_names}
for partition_name, sample_idx in partitions_sample_idx.items():
part_num_samples = len(sample_idx)
partitions_recordings[partition_name].info['num_traj'] = part_num_samples
partitions_recordings[partition_name].recordings = dict()
for obs_name in dyn_recording.recordings.keys():
data = dyn_recording.recordings[obs_name][sample_idx]
partitions_recordings[partition_name].recordings[obs_name] = data

return partitions_recordings['train'], partitions_recordings['val'], partitions_recordings['test']




def get_dynamics_dataset(train_shards: list[Path],
Expand Down
164 changes: 72 additions & 92 deletions morpho_symm/data/mini_cheetah/read_recordings.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def get_kinematic_three_rep(G: Group):
return rep_kin_three


def get_ground_reaction_forces_rep(G: Group, rep_kin_three: Representation):
def get_Rd_signals_on_kin_subchains(G: Group, rep_kin_three: Representation):
rep_R3 = G.representations['R3']
rep_F = {G.identity: np.eye(12, dtype=int)}
gens = [np.kron(rep_kin_three(g), rep_R3(g)) for g in G.generators]
Expand Down Expand Up @@ -73,25 +73,28 @@ def convert_mini_cheetah_raysim_recordings(data_path: Path):
rep_Rd = G.representations['R3'] # Representation on vectors in R^d
rep_Rd_pseudo = G.representations['R3_pseudo'] # Representation on pseudo vectors in R^d
rep_euler_xyz = G.representations['euler_xyz'] # Representation on Euler angles
rep_z = group_rep_from_gens(G, rep_H={h: rep_Rd(h)[2,2].reshape((1,1)) for h in G.elements if h != G.identity})

# Define observation variables and their group representations z
base_pos, base_pos_rep = state[:, :3], rep_Rd
base_z, base_z_rep = state[:, [2]], rep_z
base_vel, base_vel_rep = state[:, 3:6], rep_Rd
base_ori, base_ori_rep = state[:, 6:9], rep_euler_xyz
base_ang_vel, base_ang_vel_rep = state[:, 9:12], rep_Rd_pseudo # Pseudo vector
feet_pos, feet_pos_rep = state[:, 12:24], directsum([rep_Rd] * 4, name='Rd^4')
joint_vel, joint_vel_rep = state[:, 36:48], rep_TqQ_js
joint_torques, joint_torques_rep = state[:, 48:60], rep_TqQ_js
rep_kin_three = get_kinematic_three_rep(G)
gait, gait_rep = state[:, 60:64], rep_kin_three # TODO
ref_base_z, ref_base_z_rep = state[:, [64]], G.trivial_representation
ref_base_vel, ref_base_vel_rep = state[:, 65:68], rep_Rd
ref_base_ori, ref_base_ori_rep = state[:, 68:71], rep_Rd
ref_base_ang_vel, ref_base_ang_vel_rep = state[:, 71:74], rep_Rd_pseudo
ref_feet_pos, rep_feet_pos = state[:, 74:86], get_ground_reaction_forces_rep(G, rep_kin_three) # TODO

rep_kin_three = get_kinematic_three_rep(G) # Permutation of legs
rep_Rd_on_limbs = get_Rd_signals_on_kin_subchains(G, rep_kin_three) # Representation on R^3 on legs

rep_z = group_rep_from_gens(G, rep_H={h: rep_Rd(h)[2, 2].reshape((1, 1)) for h in G.elements if h != G.identity})
rep_z.name = "base_z"

# Define observation variables and their group representations

# Base body observations ___________________________________________________________________________________________
base_pos = state[:, :3] # Rep: rep_Rd
base_z = state[:, [2]] # Rep: rep_z
base_vel = state[:, 3:6] # Rep: rep_Rd
base_ori = state[:, 6:9] # Rep: rep_euler_xyz
base_ang_vel = state[:, 9:12] # Rep: rep_euler_xyz
ref_base_z = state[:, [64]] # Rep: rep_z
ref_base_vel = state[:, 65:68] # Rep: rep_Rd
ref_base_ori = state[:, 68:71] # Rep: rep_euler_xyz
ref_base_ang_vel = state[:, 71:74] # Rep: rep_euler_xyz
base_z_error = base_z - ref_base_z # Rep: rep_z
base_vel_error = base_vel - ref_base_vel # Rep: rep_Rd
base_ang_vel_error = base_ang_vel - ref_base_ang_vel # Rep: rep_euler_xyz
base_ori_error = base_ori - ref_base_ori # Rep: rep_euler_xyz
# Define the representation of the rotation matrix R that transforms the base orientation.
rep_rot_flat = {}
# R = Rotation.from_euler("xyz", base_ori[2]).as_matrix()
Expand All @@ -102,14 +105,14 @@ def convert_mini_cheetah_raysim_recordings(data_path: Path):
base_ori_R = np.asarray([Rotation.from_euler("xyz", ori).as_matrix() for ori in base_ori])
base_ori_R_flat = base_ori_R.reshape(base_ori.shape[0], -1)

# g = G.sample()
# g_R = rep_Rd(g) @ R @ rep_Rd(~g)
# vectorize R row-wise
# R_flat = R.reshape(-1,)
# g_R_flat = rep_rot_flat(g) @ R_flat
# g_RR = g_R_flat.reshape(3, 3)
# assert np.allclose(g_R, g_RR), "g_R and g_RR are not equal"

# Euclidean space observations _____________________________________________________________________________________
feet_pos = state[:, 12:24] # Rep: rep_Rd_on_limbs
gait = state[:, 60:64] # Rep: rep_kin_three
ref_feet_pos = state[:, 74:86] # Rep: rep_Rd_on_limbs
feet_pos_error = feet_pos - ref_feet_pos # Rep: rep_Rd_on_limbs
# Joint-Space observations _________________________________________________________________________________________
joint_vel = state[:, 36:48]
joint_torques = state[:, 48:60]
# Joint positions need to be converted to the unit circle parametrization [cos(q), sin(q)].
# For God’s sake, we need to avoid using PyBullet.
bullet_client = BulletClient(connection_mode=pybullet.DIRECT)
Expand All @@ -121,14 +124,10 @@ def convert_mini_cheetah_raysim_recordings(data_path: Path):
# Define joint positions [q1, q2, ..., qn] -> [cos(q1), sin(q1), ..., cos(qn), sin(qn)] format.
q_js_unit_circle_t = np.stack([cos_q_js, sin_q_js], axis=2)
q_js_unit_circle_t = q_js_unit_circle_t.reshape(q_js_unit_circle_t.shape[0], -1)
joint_pos, joint_pos_rep = q_js_unit_circle_t, rep_Q_js # Joints in angle not unit circle representation

# Compute Relative observations that affect the system evolution.
base_z_error = base_z - ref_base_z
base_vel_error = base_vel - ref_base_vel
base_ang_vel_error = base_ang_vel - ref_base_ang_vel
joint_pos_S1, joint_pos_rep = q_js_unit_circle_t, rep_Q_js # Joints in angle not unit circle representation
joint_pos = q_js_ms # Joints in angle representation

# Subsample the data by skippig by ignoring odd frames.
# Subsample the data by skippig by ignoring odd frames. ============================================================
dt_subsample = 3
base_pos = base_pos[::dt_subsample]
base_z = base_z[::dt_subsample]
Expand All @@ -138,6 +137,7 @@ def convert_mini_cheetah_raysim_recordings(data_path: Path):
base_ang_vel = base_ang_vel[::dt_subsample]
feet_pos = feet_pos[::dt_subsample]
joint_pos = joint_pos[::dt_subsample]
joint_pos_S1 = joint_pos_S1[::dt_subsample]
joint_vel = joint_vel[::dt_subsample]
joint_torques = joint_torques[::dt_subsample]
gait = gait[::dt_subsample]
Expand All @@ -146,10 +146,12 @@ def convert_mini_cheetah_raysim_recordings(data_path: Path):
ref_base_ori = ref_base_ori[::dt_subsample]
ref_base_ang_vel = ref_base_ang_vel[::dt_subsample]
ref_feet_pos = ref_feet_pos[::dt_subsample]
feet_pos_error = feet_pos_error[::dt_subsample]
base_z_error = base_z_error[::dt_subsample]
base_vel_error = base_vel_error[::dt_subsample]
base_ang_vel_error = base_ang_vel_error[::dt_subsample]

base_ori_error = base_ori_error[::dt_subsample]
# Define the dataset.
data_recording = DynamicsRecording(
description=f"Mini Cheetah {data_path.parent.parent.stem}",
info=dict(num_traj=1,
Expand All @@ -162,41 +164,41 @@ def convert_mini_cheetah_raysim_recordings(data_path: Path):
base_ori_R_flat=base_ori_R_flat[None, ...].astype(np.float32),
base_ang_vel=base_ang_vel[None, ...].astype(np.float32),
feet_pos=feet_pos[None, ...].astype(np.float32),
feet_pos_error=feet_pos_error[None, ...].astype(np.float32),
joint_pos=joint_pos[None, ...].astype(np.float32),
joint_pos_S1=joint_pos_S1[None, ...].astype(np.float32),
joint_vel=joint_vel[None, ...].astype(np.float32),
joint_torques=joint_torques[None, ...].astype(np.float32),
gait=gait[None, ...].astype(np.float32),
ref_base_vel=ref_base_vel[None, ...].astype(np.float32),
ref_base_ori=ref_base_ori[None, ...].astype(np.float32),
ref_base_ang_vel=ref_base_ang_vel[None, ...].astype(np.float32),
ref_feet_pos=ref_feet_pos[None, ...].astype(np.float32),
base_z_error=base_z_error[None, ...].astype(np.float32),
base_vel_error=base_vel_error[None, ...].astype(np.float32),
base_ang_vel_error=base_ang_vel_error[None, ...].astype(np.float32),
base_ori_error=base_ori_error[None, ...].astype(np.float32),
),
state_obs=('joint_pos', 'joint_vel', 'base_ori_R_flat', 'base_z_error', 'base_vel_error', 'base_ang_vel_error'),
state_obs=('joint_pos', 'joint_vel', 'base_z_error', 'base_ori', 'base_ori_error', 'base_vel_error', 'base_ang_vel_error'),
action_obs=('joint_torques',),
obs_representations=dict(base_pos=base_pos_rep,
base_z=G.trivial_representation,
base_vel=base_vel_rep,
base_ori=base_ori_rep,
obs_representations=dict(joint_pos=rep_TqQ_js, # Joint-Space observations
joint_pos_S1=rep_Q_js,
joint_vel=rep_TqQ_js,
joint_torques=rep_TqQ_js,
# Base body observations
base_pos=rep_Rd,
base_z=rep_z,
base_z_error=rep_z,
base_vel=rep_Rd,
base_vel_error=rep_Rd,
base_ori=rep_euler_xyz,
base_ori_R_flat=rep_rot_flat,
base_ang_vel=base_ang_vel_rep,
ref_base_vel=ref_base_vel_rep,
ref_base_ori=ref_base_ori_rep,
ref_base_ang_vel=ref_base_ang_vel_rep,
feet_pos=feet_pos_rep,
joint_pos=joint_pos_rep,
joint_vel=joint_vel_rep,
joint_torques=joint_torques_rep,
gait=gait_rep,
ref_feet_pos=rep_feet_pos,
base_z_error=G.trivial_representation,
base_vel_error=base_vel_rep,
base_ang_vel_error=base_ang_vel_rep,
base_ang_vel=rep_euler_xyz,
base_ang_vel_error=rep_euler_xyz,
base_ori_error=rep_euler_xyz,
# Euclidean space observations
feet_pos=rep_Rd_on_limbs,
feet_pos_error=rep_Rd_on_limbs,
gait=rep_kin_three,
),
# Ensure the angles in the unit circle are not disturbed by the normalization.
obs_moments=dict(joint_pos=(np.zeros(q_js_unit_circle_t.shape[-1]), np.ones(q_js_unit_circle_t.shape[-1]),),
obs_moments=dict(joint_pos_S1=(np.zeros(q_js_unit_circle_t.shape[-1]), np.ones(q_js_unit_circle_t.shape[-1]),),
base_ori_R_flat=(np.zeros(base_ori_R_flat.shape[-1]), np.ones(base_ori_R_flat.shape[-1]),),
)
)
Expand All @@ -207,38 +209,16 @@ def convert_mini_cheetah_raysim_recordings(data_path: Path):
continue
data_recording.compute_obs_moments(obs_name=obs_name)

train_record, val_recording, test_record = split_train_val_test(data_recording)

for part_record in [test_record, val_recording]:
# Do "Hard" data-augmentation, as we want to evaluate the capacity of the models to predict the
# physics of the dynamics of the system. Although data comes from a single trajectory, because of the
# equivariance of Newtonian physics, the models should be able to predict the dynamics of the system
# for symmetric trajectories.
for obs_name in part_record.recordings.keys():
obs_rep = part_record.obs_representations[obs_name]
obs_traj = part_record.recordings[obs_name]
orbit = [obs_traj]
for g in G.elements:
if g == G.identity: continue # Already added
orbit.append(np.einsum('...ij,...j->...i', obs_rep(g), obs_traj))
obs_traj_orbit = np.concatenate(orbit, axis=0)
part_record.recordings[obs_name] = obs_traj_orbit
part_record.info['num_traj'] = obs_traj_orbit.shape[0]

for partition_name, recording in zip(['train', 'val', 'test'], [train_record, val_recording, test_record]):
file_name = (f"n_trajs={recording.info['num_traj']}"
f"-frames={recording.info['trajectory_length']}"
f"-{partition_name}.pkl")
recording.save_to_file(data_path.parent.parent / file_name)
print(f"Dynamics Recording saved to {data_path.parent.parent / file_name}")

# file_path = data_path.parent.parent / "recording"
# data_recording.save_to_file(file_path)
print(f"Dynamics Recording saved to")


#
file_name = (f"n_trajs={data_recording.info['num_traj']}"
f"-frames={data_recording.info['trajectory_length']}.pkl")
data_recording.save_to_file(data_path.parent.parent / file_name)
print(f"Dynamics Recording saved to {data_path.parent.parent / file_name}")


if __name__ == "__main__":
data_path = Path("raysim_recordings/uneven_easy/forward_minus_0_4/heightmap_logger/state_reduced_nmpc.npy")
convert_mini_cheetah_raysim_recordings(data_path)
terrains = ["flat", "uneven_easy", "uneven_medium", "uneven_hard_squares"]
modes = ["forward_minus_0_4", "forward_minus_0_4_yawrate_0_4", "forward_minus_0_4_yawrate_minus_0_4"]
for terrain in terrains:
for mode in modes:
data_path = Path(f"raysim_recordings/{terrain}/{mode}/heightmap_logger/state_reduced_nmpc.npy")
convert_mini_cheetah_raysim_recordings(data_path)

0 comments on commit 52f2c51

Please sign in to comment.