Skip to content

Commit

Permalink
dev: enable plotting 3d obj space with pf
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Nov 4, 2024
1 parent 9d592fc commit 3a554be
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions src/evox/vis_tools/plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import jax.numpy as jnp

from evox import use_state


def plot_dec_space(
population_history,
Expand Down Expand Up @@ -64,7 +62,6 @@ def plot_dec_space(
sliders = [
{
"currentvalue": {"prefix": "Generation: "},
"pad": {"t": 50},
"pad": {"b": 1, "t": 10},
"len": 0.8,
"x": 0.2,
Expand All @@ -82,7 +79,6 @@ def plot_dec_space(
"x": 1,
"y": 1,
"xanchor": "auto",
"xanchor": "auto",
},
margin={"l": 0, "r": 0, "t": 0, "b": 0},
sliders=sliders,
Expand Down Expand Up @@ -115,7 +111,6 @@ def plot_dec_space(
"frame": {"duration": 0, "redraw": False},
"mode": "immediate",
"transition": {"duration": 0},
"mode": "immediate",
},
],
"label": "Pause",
Expand Down Expand Up @@ -171,7 +166,6 @@ def plot_obj_space_1d_no_animation(fitness_history, **kwargs):
"x": 1,
"y": 1,
"xanchor": "auto",
"xanchor": "auto",
},
margin={"l": 0, "r": 0, "t": 0, "b": 0},
),
Expand Down Expand Up @@ -268,7 +262,6 @@ def plot_obj_space_1d_animation(fitness_history, **kwargs):
"x": 1,
"y": 1,
"xanchor": "auto",
"xanchor": "auto",
},
margin={"l": 0, "r": 0, "t": 0, "b": 0},
sliders=sliders,
Expand Down Expand Up @@ -397,7 +390,6 @@ def plot_obj_space_2d(fitness_history, problem_pf=None, sort_points=False, **kwa
"x": 1,
"y": 1,
"xanchor": "auto",
"xanchor": "auto",
},
margin={"l": 0, "r": 0, "t": 0, "b": 0},
sliders=sliders,
Expand Down Expand Up @@ -479,6 +471,17 @@ def plot_obj_space_3d(fitness_history, sort_points=False, problem_pf=None, **kwa

frames = []
steps = []

if problem_pf is not None:
pf_scatter = go.Scatter3d(
x=problem_pf[:, 0],
y=problem_pf[:, 1],
z=problem_pf[:, 2],
mode="markers",
marker={"color": "#FFA15A", "size": 2},
name="Pareto Front",
)

for i, fit in enumerate(fitness_history):
# it will make the animation look nicer
if sort_points:
Expand All @@ -492,7 +495,10 @@ def plot_obj_space_3d(fitness_history, sort_points=False, problem_pf=None, **kwa
mode="markers",
marker={"color": "#636EFA", "size": 2},
)
frames.append(go.Frame(data=[scatter], name=str(i)))
if problem_pf is not None:
frames.append(go.Frame(data=[pf_scatter, scatter], name=str(i)))
else:
frames.append(go.Frame(data=[scatter], name=str(i)))

step = {
"label": i,
Expand Down

0 comments on commit 3a554be

Please sign in to comment.