Skip to content

Commit

Permalink
[widget] Make foods draggable
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Nov 16, 2024
1 parent ad2c585 commit a4d06fd
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 24 deletions.
74 changes: 55 additions & 19 deletions src/emevo/analysis/mgl_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,6 @@ def __init__(
self._scaling = x_range / figsize[0], y_range / figsize[1]
self._phys_state = saved_physics
self._index = start
# For dragging
self._last_mouse_pos = None
self._xy_max = jnp.expand_dims(jnp.array([x_range, y_range]), axis=0)
self._dragged_state = None
self._make_renderer = partial(
MglRenderer,
Expand Down Expand Up @@ -100,7 +97,13 @@ def __init__(
self.setFixedSize(*self._figsize)
self.setMouseTracking(True)
self._ctx, self._fbo = None, None
# For dragging
self._last_mouse_pos = None
self._dragging_agent = False
self._dragging_food = False
self._xy_max = jnp.expand_dims(jnp.array([x_range, y_range]), axis=0)
self._selected_slot = 0
self._selected_food_slot = 0

def _scale_position(self, position: QPointF) -> tuple[float, float]:
return (
Expand Down Expand Up @@ -152,42 +155,75 @@ def exitable(self) -> bool:
return self._end_index - 1 <= self._index

def mousePressEvent(self, evt: QMouseEvent) -> None: # type: ignore
if evt.button() != Qt.LeftButton:
return
position = self._scale_position(evt.position())
circle = self._get_stated().circle
overlap = _overlap(
jnp.array(position),
sd = self._get_stated()
posarray = jnp.array(position)

def _get_selected(state: State, shape: Circle) -> int | None:
overlap = _overlap(posarray, shape, state)
(selected,) = jnp.nonzero(overlap)
if 0 < selected.shape[0]:
return selected[0].item()
else:
return None

selected = _get_selected(
sd.circle,
self._env._physics.shaped.circle,
circle,
)
(selected,) = jnp.nonzero(overlap)
if 0 < selected.shape[0]:
self._selected_slot = selected[0].item()
if selected is not None:
self._selected_slot = selected
self.selectionChanged.emit(self._selected_slot, self._index)

# Initialize dragging
if evt.button() == Qt.LeftButton and self._paused:
if self._paused:
self._last_mouse_pos = Vec2d(*position)
self._dragging_agent = True

selected = _get_selected(
sd.static_circle,
self._env._physics.shaped.static_circle,
)
if selected is not None and self._paused:
self._selected_food_slot = selected
self._last_mouse_pos = Vec2d(*position)
self._dragging_food = True

def mouseReleaseEvent(self, evt: QMouseEvent) -> None:
if evt.button() == Qt.LeftButton:
self._last_mouse_pos = None
self._dragging_food = False
self._dragging_agent = False

def mouseMoveEvent(self, evt: QMouseEvent) -> None:
current_pos = Vec2d(*self._scale_position(evt.position()))

if self._selected_slot is not None and self._last_mouse_pos is not None:
dragging = self._dragging_agent or self._dragging_food
if self._last_mouse_pos is not None and dragging:
# Compute dx/dy
dxy = current_pos - self._last_mouse_pos

# Update the physics state
stated = self._get_stated()
circle = stated.circle
xy = jnp.clip(
circle.p.xy.at[self._selected_slot].add(jnp.array(dxy)),
min=0.0,
max=self._xy_max,
)
self._dragged_state = stated.nested_replace("circle.p.xy", xy)
if self._dragging_agent:
circle = stated.circle
xy = jnp.clip(
circle.p.xy.at[self._selected_slot].add(jnp.array(dxy)),
min=self._env._agent_radius,
max=self._xy_max - self._env._agent_radius,
)
self._dragged_state = stated.nested_replace("circle.p.xy", xy)
elif self._dragging_food:
static_circle = stated.static_circle
xy = jnp.clip(
static_circle.p.xy.at[self._selected_food_slot].add(jnp.array(dxy)),
min=self._env._food_radius,
max=self._xy_max - self._env._food_radius,
)
self._dragged_state = stated.nested_replace("static_circle.p.xy", xy)

self._last_mouse_pos = current_pos
self.update()

Expand Down
6 changes: 1 addition & 5 deletions src/emevo/analysis/qt_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def __init__(
self._energy_cm = mpl.colormaps["YlGnBu"]
self._n_children_cm = mpl.colormaps["PuBuGn"]
self._food_cm = mpl.colormaps["plasma"]
self.__cm = mpl.colormaps["plasma"]
self._norm = mc.Normalize(vmin=0.0, vmax=1.0)
self._cm_fixed_minmax = {} if cm_fixed_minmax is None else cm_fixed_minmax
if profile_and_rewards is not None:
Expand Down Expand Up @@ -264,10 +263,7 @@ def __init__(
if profile_and_rewards is not None:
self.rewardUpdated.connect(self._reward_widget.updateValues)
# Initial size
if profile_and_rewards is None:
self.resize(xlim * scale * 1.6, ylim * scale * 1.75)
else:
self.resize(xlim * scale * 1.6, ylim * scale * 1.4)
self.resize(int(xlim * scale * 1.6), int(ylim * scale * 1.4))
self._self_terminate = self_terminate

def _check_exit(self) -> None:
Expand Down

0 comments on commit a4d06fd

Please sign in to comment.