-
Notifications
You must be signed in to change notification settings - Fork 0
/
cpu_vs_numba_vs_jax_simple.py
64 lines (50 loc) · 1.46 KB
/
cpu_vs_numba_vs_jax_simple.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Compare CPU vs numba vs JAX for a simple task of filling in a matrix
import numba as nb
import numpy as np
import time
import jax
import jax.numpy as jnp
def f_cpu(N, P):
    """Fill an (N, P) float64 matrix with W[i, j] = i**2 + j**2 in pure Python.

    This is the interpreted baseline of the benchmark: every element is
    written by an explicit nested Python loop.

    Args:
        N: number of rows.
        P: number of columns.

    Returns:
        numpy.ndarray of shape (N, P), dtype float64.
    """
    grid = np.zeros((N, P))
    for row in range(N):
        for col in range(P):
            grid[row, col] = row * row + col * col
    return grid
@nb.njit(cache=True)
def f_numba(N, P):
    """Numba-compiled twin of the pure-Python fill: W[i, j] = i**2 + j**2.

    The loop is the same as the interpreted baseline, but ``nb.njit`` compiles
    it in nopython mode. ``cache=True`` writes the compiled code to disk so a
    later process can reuse it instead of recompiling.

    Args:
        N: number of rows.
        P: number of columns.

    Returns:
        numpy.ndarray of shape (N, P), dtype float64.
    """
    grid = np.zeros((N, P))
    for row in range(N):
        for col in range(P):
            grid[row, col] = row * row + col * col
    return grid
def f_jax_ij(i, j):
    """Return the value of a single matrix entry: ``i**2 + j**2``.

    This is the scalar kernel; ``f_jax`` lifts it over the full index grid
    with nested ``jax.vmap``.
    """
    return i ** 2 + j ** 2


@jax.jit
def _f_jax_grid(rows, cols):
    """Evaluate ``f_jax_ij`` on every (rows[a], cols[b]) pair, jit-compiled.

    The inner vmap maps over ``cols`` with a fixed row index, producing one
    row of the result. The outer vmap maps over ``rows`` and stacks those
    rows, giving shape (len(rows), len(cols)). Wrapping this in ``jax.jit``
    means it is traced and compiled once per input shape/dtype, and later
    calls reuse the compiled program.
    """
    grid_fn = jax.vmap(jax.vmap(f_jax_ij, in_axes=(None, 0)), in_axes=(0, None))
    return grid_fn(rows, cols)


def f_jax(N, P, dtype=None):
    """Build the (N, P) matrix W[i, j] = i**2 + j**2 with JAX (vmap + jit).

    Args:
        N: number of rows; must be a concrete int because it sizes
            ``jnp.arange``.
        P: number of columns; same constraint as ``N``.
        dtype: optional dtype for the index vectors, and therefore for the
            result. The default ``None`` keeps ``jnp.arange``'s default
            integer dtype, which is int32 unless JAX's x64 mode is enabled.
            With int32, ``i**2 + j**2`` silently wraps once it exceeds
            2**31 - 1 (i.e. for N and P both above roughly 32k). Pass a
            wider or float dtype to avoid that.

    Returns:
        jax.Array of shape (N, P). JAX dispatches asynchronously, so call
        ``.block_until_ready()`` on the result before stopping a timer.
    """
    rows = jnp.arange(N, dtype=dtype)
    cols = jnp.arange(P, dtype=dtype)
    return _f_jax_grid(rows, cols)
if __name__ == "__main__":
    # Benchmark size. The CPU and Numba results are float64, so each
    # N x P matrix takes N * P * 8 bytes (~800 MB at 10000 x 10000).
    N = 10000
    P = 10000

    # perf_counter is monotonic and high-resolution, which suits interval
    # timing better than time.time().

    # CPU (pure-Python baseline)
    start = time.perf_counter()
    W_cpu = f_cpu(N, P)
    t_cpu = time.perf_counter() - start
    print(f"CPU time: {t_cpu:.3f}s")

    # Numba and JAX are each called once untimed, so compilation (or loading
    # Numba's on-disk cache) is excluded from the timed second call. The
    # warm-up results are not bound to a name, so they are freed at once
    # instead of adding to peak memory.

    # Numba
    f_numba(N, P)
    start = time.perf_counter()
    W_numba = f_numba(N, P)
    t_numba = time.perf_counter() - start
    print(f"Numba time: {t_numba:.3f}s")

    # JAX dispatches work asynchronously. Block on the warm-up as well as on
    # the timed call, so the timed interval covers exactly one computation.
    f_jax(N, P).block_until_ready()
    start = time.perf_counter()
    W_jax = f_jax(N, P)
    W_jax.block_until_ready()
    t_jax = time.perf_counter() - start
    print(f"JAX time: {t_jax:.3f}s")

    print(f"\nNumba / CPU speed up: {t_cpu / t_numba:.3f}x")
    print(f"JAX / numba speed up: {t_numba / t_jax:.3f}x")

    # Explicit result checks. Unlike bare `assert`, these still run under
    # `python -O`. The JAX result is copied to a host NumPy array first.
    np.testing.assert_allclose(W_numba, W_cpu, rtol=0, atol=1e-4)
    np.testing.assert_allclose(np.asarray(W_jax), W_cpu, rtol=0, atol=1e-4)