Skip to content

Commit

Permalink
Fix misc in vis
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Nov 21, 2024
1 parent 2cb8e44 commit 771bab5
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
7 changes: 3 additions & 4 deletions experiments/cf_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,10 +644,9 @@ def vis_policy(
sensor_color=np.array([0.0, 0.0, 0.0, 0.5], dtype=np.float32),
)
# I don't know why this works...
visualizer.render(env_state.physics)
visualizer.show()
visualizer.render(env_state.physics)
visualizer.show()
for _ in range(3):
visualizer.render(env_state.physics)
visualizer.show()
env._sensor_index = ag_idx # type:ignore
visualizer.render(env_state.physics)
images.append(visualizer.get_image())
Expand Down
15 changes: 9 additions & 6 deletions src/emevo/analysis/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,29 +94,32 @@ def draw_cf_policy_multi(
) -> None:
n_policies = len(names)
n_obs, n_policies = policy_means.shape[:2]
observations = [f"Observation {i+1}" for i in range(n_obs)]
fig, axes = plt.subplots(
nrows=n_obs,
ncols=n_policies,
figsize=(n_policies * fig_unit, n_obs * fig_unit),
nrows=n_policies,
ncols=n_obs,
figsize=(n_obs * fig_unit, n_policies * fig_unit),
)
fig.tight_layout()
# Arrow points
center = Vec2d(max_force * 1.5, max_force * 1.5)
unit = Vec2d(0.0, 1.0)
# Draw the arrows
for i, (title, rot) in enumerate(zip(names, rotations)):
for j, (obs_title, rot) in enumerate(zip(observations, rotations)):
d_unit = unit.rotated(rot)
s_left = unit.rotated(math.pi * 1.25 + rot) * max_force * 0.5 + center
s_right = unit.rotated(math.pi * 0.75 + rot) * max_force * 0.5 + center
for j, policy_mean in enumerate(policy_means[i]):
for i, policy_mean in enumerate(policy_means[j]):
ax = axes[i][j]
ax.set_xlim(0, max_force * 3)
ax.set_ylim(0, max_force * 3)
ax.set_xticks([])
ax.set_yticks([])
ax.set_aspect("equal", adjustable="box")
if i == 0:
ax.set_title(title)
ax.set_title(obs_title)
if j == 0:
ax.set_ylabel(names[i])
# Circle
circle = Circle((center.x, center.y), max_force * 0.5, fill=False)
ax.add_patch(circle)
Expand Down

0 comments on commit 771bab5

Please sign in to comment.