Skip to content

Commit

Permalink
Fix policy visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Nov 14, 2024
1 parent 99f40f4 commit 281445d
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions src/emevo/analysis/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,9 @@ def draw_cf_policy(
# Arrow points
center = Vec2d(max_force * 1.5, max_force * 1.5)
unit = Vec2d(0.0, 1.0)
u_left = unit.rotated(math.pi * 1.25 + rotation)
u_right = unit.rotated(math.pi * 0.75 + rotation)
s_left = u_left * max_force * 0.5 + center
s_right = u_right * max_force * 0.5 + center
d_unit = unit.rotated(math.pi)
s_left = unit.rotated(math.pi * 1.25 + rotation) * max_force * 0.5 + center
s_right = unit.rotated(math.pi * 0.75 + rotation) * max_force * 0.5 + center
# Draw the arrows
for title, policy_mean, ax in zip(names, policy_means, np.ravel(axes)):
# Misc
Expand All @@ -56,10 +55,11 @@ def draw_cf_policy(
circle = Circle((center.x, center.y), max_force * 0.5, fill=False)
ax.add_patch(circle)
# Left
d_left = u_left * policy_mean[0].item()
d_left = d_unit * policy_mean[0].item()
s_left_shifted = s_left - d_left
arrow = Arrow(
s_left.x,
s_left.y,
s_left_shifted.x,
s_left_shifted.y,
d_left.x,
d_left.y,
# 10% of the width? Looks thinner...
Expand All @@ -68,10 +68,11 @@ def draw_cf_policy(
)
ax.add_patch(arrow)
# Right
d_right = u_right * policy_mean[1].item()
d_right = d_unit * policy_mean[1].item()
s_right_shifted = s_right - d_right
arrow = Arrow(
s_right.x,
s_right.y,
s_right_shifted.x,
s_right_shifted.y,
d_right.x,
d_right.y,
width=max_force * 0.3,
Expand Down

0 comments on commit 281445d

Please sign in to comment.