Skip to content

Commit 092db70

Browse files
committed
doc: add multidevice algorithm example
1 parent 1d65b60 commit 092db70

File tree

1 file changed

+272
-0
lines changed

1 file changed

+272
-0
lines changed
Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 2,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"from typing import Optional\n",
10+
"\n",
11+
"import jax\n",
12+
"import jax.numpy as jnp\n",
13+
"\n",
14+
"from evox import Algorithm, dataclass, pytree_field, problems, workflows, monitors, use_state\n",
15+
"from evox.core.distributed import ShardingType\n",
16+
"from evox.utils import *"
17+
]
18+
},
19+
{
20+
"cell_type": "markdown",
21+
"metadata": {},
22+
"source": [
23+
"In this example, we consider the following simple setup:\n",
24+
"```\n",
25+
" Node1\n",
26+
" |\n",
27+
" +----+----+\n",
28+
" | | |\n",
29+
"GPU GPU GPU\n",
30+
"```\n",
31+
"Where we only have one node with multiple GPUs. The communication between the GPUs is done through the PCIe or NVLink.\n",
32+
"When running in a distributed setup, we need to make decisions on how to place the data on these GPUs."
33+
]
34+
},
35+
{
36+
"cell_type": "code",
37+
"execution_count": 3,
38+
"metadata": {},
39+
"outputs": [],
40+
"source": [
41+
"# The only changes:\n",
42+
"# Add the sharding metadata\n",
43+
"@dataclass\n",
44+
"class SpecialPSOState:\n",
45+
" population: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)\n",
46+
" velocity: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)\n",
47+
" fitness: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)\n",
48+
" local_best_location: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)\n",
49+
" local_best_fitness: jax.Array = pytree_field(sharding=ShardingType.SHARED_FIRST_DIM)\n",
50+
" global_best_location: jax.Array\n",
51+
" global_best_fitness: jax.Array\n",
52+
" key: jax.random.PRNGKey\n",
53+
"\n",
54+
"\n",
55+
"@dataclass\n",
56+
"class PSO(Algorithm):\n",
57+
" dim: jax.Array = pytree_field(static=True, init=False)\n",
58+
" lb: jax.Array\n",
59+
" ub: jax.Array\n",
60+
" pop_size: jax.Array = pytree_field(static=True)\n",
61+
" w: jax.Array = pytree_field(default=0.6)\n",
62+
" phi_p: jax.Array = pytree_field(default=2.5)\n",
63+
" phi_g: jax.Array = pytree_field(default=0.8)\n",
64+
" mean: Optional[jax.Array] = pytree_field(default=None)\n",
65+
" stdev: Optional[jax.Array] = pytree_field(default=None)\n",
66+
" bound_method: str = pytree_field(static=True, default=\"clip\")\n",
67+
"\n",
68+
" def __post_init__(self):\n",
69+
" self.set_frozen_attr(\"dim\", self.lb.shape[0])\n",
70+
"\n",
71+
" def setup(self, key):\n",
72+
" state_key, init_pop_key, init_v_key = jax.random.split(key, 3)\n",
73+
" if self.mean is not None and self.stdev is not None:\n",
74+
" population = self.stdev * jax.random.normal(\n",
75+
" init_pop_key, shape=(self.pop_size, self.dim)\n",
76+
" )\n",
77+
" population = jnp.clip(population, self.lb, self.ub)\n",
78+
" velocity = self.stdev * jax.random.normal(\n",
79+
" init_v_key, shape=(self.pop_size, self.dim)\n",
80+
" )\n",
81+
" else:\n",
82+
" length = self.ub - self.lb\n",
83+
" population = jax.random.uniform(\n",
84+
" init_pop_key, shape=(self.pop_size, self.dim)\n",
85+
" )\n",
86+
" population = population * length + self.lb\n",
87+
" velocity = jax.random.uniform(init_v_key, shape=(self.pop_size, self.dim))\n",
88+
" velocity = velocity * length * 2 - length\n",
89+
"\n",
90+
" return SpecialPSOState(\n",
91+
" population=population,\n",
92+
" velocity=velocity,\n",
93+
" fitness=jnp.full((self.pop_size,), jnp.inf),\n",
94+
" local_best_location=population,\n",
95+
" local_best_fitness=jnp.full((self.pop_size,), jnp.inf),\n",
96+
" global_best_location=population[0],\n",
97+
" global_best_fitness=jnp.array([jnp.inf]),\n",
98+
" key=state_key,\n",
99+
" )\n",
100+
"\n",
101+
" def ask(self, state):\n",
102+
" return state.population, state\n",
103+
"\n",
104+
" def tell(self, state, fitness):\n",
105+
" key, rg_key, rp_key = jax.random.split(state.key, 3)\n",
106+
"\n",
107+
" rg = jax.random.uniform(rg_key, shape=(self.pop_size, self.dim))\n",
108+
" rp = jax.random.uniform(rp_key, shape=(self.pop_size, self.dim))\n",
109+
"\n",
110+
" compare = state.local_best_fitness > fitness\n",
111+
" local_best_location = jnp.where(\n",
112+
" compare[:, jnp.newaxis], state.population, state.local_best_location\n",
113+
" )\n",
114+
" local_best_fitness = jnp.minimum(state.local_best_fitness, fitness)\n",
115+
"\n",
116+
" global_best_location, global_best_fitness = min_by(\n",
117+
" [state.global_best_location[jnp.newaxis, :], state.population],\n",
118+
" [state.global_best_fitness, fitness],\n",
119+
" )\n",
120+
"\n",
121+
" global_best_fitness = jnp.atleast_1d(global_best_fitness)\n",
122+
"\n",
123+
" velocity = (\n",
124+
" self.w * state.velocity\n",
125+
" + self.phi_p * rp * (local_best_location - state.population)\n",
126+
" + self.phi_g * rg * (global_best_location - state.population)\n",
127+
" )\n",
128+
" population = state.population + velocity\n",
129+
"\n",
130+
" if self.bound_method == \"clip\":\n",
131+
" population = jnp.clip(population, self.lb, self.ub)\n",
132+
" velocity = jnp.clip(velocity, self.lb, self.ub)\n",
133+
" elif self.bound_method == \"reflect\":\n",
134+
" lower_bound_violation = population < self.lb\n",
135+
" upper_bound_violation = population > self.ub\n",
136+
"\n",
137+
" population = jnp.where(\n",
138+
" lower_bound_violation, 2 * self.lb - population, population\n",
139+
" )\n",
140+
" population = jnp.where(\n",
141+
" upper_bound_violation, 2 * self.ub - population, population\n",
142+
" )\n",
143+
" velocity = jnp.where(\n",
144+
" lower_bound_violation | upper_bound_violation, -velocity, velocity\n",
145+
" )\n",
146+
" # enforce the bounds in case the reflected particles are still out of bounds\n",
147+
" population = jnp.clip(population, self.lb, self.ub)\n",
148+
" velocity = jnp.clip(velocity, self.lb, self.ub)\n",
149+
"\n",
150+
" return state.replace(\n",
151+
" population=population,\n",
152+
" velocity=velocity,\n",
153+
" local_best_location=local_best_location,\n",
154+
" local_best_fitness=local_best_fitness,\n",
155+
" global_best_location=global_best_location,\n",
156+
" global_best_fitness=global_best_fitness,\n",
157+
" key=key,\n",
158+
" )\n"
159+
]
160+
},
161+
{
162+
"cell_type": "code",
163+
"execution_count": 4,
164+
"metadata": {},
165+
"outputs": [],
166+
"source": [
167+
"pso = PSO(\n",
168+
" lb=jnp.full(shape=(2,), fill_value=-32),\n",
169+
" ub=jnp.full(shape=(2,), fill_value=32),\n",
170+
" pop_size=100,\n",
171+
")\n",
172+
"ackley = problems.numerical.Ackley()"
173+
]
174+
},
175+
{
176+
"cell_type": "code",
177+
"execution_count": 5,
178+
"metadata": {},
179+
"outputs": [],
180+
"source": [
181+
"monitor = monitors.EvalMonitor()\n",
182+
"workflow = workflows.StdWorkflow(\n",
183+
" pso,\n",
184+
" ackley,\n",
185+
" monitors=[monitor],\n",
186+
")\n",
187+
"key = jax.random.PRNGKey(42)\n",
188+
"state = workflow.init(key)\n",
189+
"state = workflow.enable_multi_devices(state)"
190+
]
191+
},
192+
{
193+
"cell_type": "code",
194+
"execution_count": 6,
195+
"metadata": {},
196+
"outputs": [
197+
{
198+
"data": {
199+
"text/plain": [
200+
"State(StdWorkflowState(generation=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device), first_step=True), {'algorithm': State(SpecialPSOState(population=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), velocity=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), local_best_location=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), local_best_fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec('POP',), memory_kind=device), global_best_location=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device), global_best_fitness=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device), key=NamedSharding(mesh=Mesh('POP': 2), spec=PartitionSpec(), memory_kind=device)), {}),'monitors0': State(EvalMonitorState(first_step=True, latest_solution=None, latest_fitness=None, topk_solutions=None, topk_fitness=None), {}),'problem': State({}, {})})"
201+
]
202+
},
203+
"execution_count": 6,
204+
"metadata": {},
205+
"output_type": "execute_result"
206+
}
207+
],
208+
"source": [
209+
"# check if the state is correctly sharded\n",
210+
"jax.tree.map(lambda x: x.sharding, state)"
211+
]
212+
},
213+
{
214+
"cell_type": "code",
215+
"execution_count": 7,
216+
"metadata": {},
217+
"outputs": [],
218+
"source": [
219+
"# run the workflow for 50 steps\n",
220+
"for i in range(50):\n",
221+
" state = workflow.step(state)"
222+
]
223+
},
224+
{
225+
"cell_type": "code",
226+
"execution_count": 8,
227+
"metadata": {},
228+
"outputs": [],
229+
"source": [
230+
"best_solution, _state = use_state(monitor.get_best_solution)(state)"
231+
]
232+
},
233+
{
234+
"cell_type": "code",
235+
"execution_count": 9,
236+
"metadata": {},
237+
"outputs": [
238+
{
239+
"name": "stdout",
240+
"output_type": "stream",
241+
"text": [
242+
"[ 0.0002041 -0.00019218]\n"
243+
]
244+
}
245+
],
246+
"source": [
247+
"print(best_solution)"
248+
]
249+
}
250+
],
251+
"metadata": {
252+
"kernelspec": {
253+
"display_name": "venv",
254+
"language": "python",
255+
"name": "python3"
256+
},
257+
"language_info": {
258+
"codemirror_mode": {
259+
"name": "ipython",
260+
"version": 3
261+
},
262+
"file_extension": ".py",
263+
"mimetype": "text/x-python",
264+
"name": "python",
265+
"nbconvert_exporter": "python",
266+
"pygments_lexer": "ipython3",
267+
"version": "3.11.2"
268+
}
269+
},
270+
"nbformat": 4,
271+
"nbformat_minor": 2
272+
}

0 commit comments

Comments
 (0)