-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdistances.py
67 lines (45 loc) · 1.58 KB
/
distances.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import jax
from jax import numpy as jnp
def dist_min_image(x, L, sdim, norm = False):
    ''' computes pairwise distances following the minimum image convention
    Args:
    - x: coords on which to compute distance; either flat, of length
      n_particles * sdim, or already shaped (n_particles, sdim)
    - L: size of the system (assumed > 0); a scalar or any array that
      broadcasts against (n_pairs, sdim)
    - sdim: spatial dimension
    - norm: if True, return the Euclidean norm of each displacement
    Returns:
    - array of shape (n_pairs, sdim) holding x_i - x_j for every pair
      i < j (row-major upper-triangle order), each component wrapped into
      [-L/2, L/2); or, when norm is True, shape (n_pairs,) with the norms.
      n_pairs = n_particles * (n_particles - 1) // 2
    '''
    x = x.reshape(-1, sdim)
    # Count particles AFTER reshaping so both flat and (n, sdim) inputs work;
    # using the raw shape[0] would undercount an already-shaped input.
    n_particles = x.shape[0]
    # diff[i, j] = x_i - x_j; keep only the strict upper triangle (i < j).
    diff = x[:, jnp.newaxis, :] - x[jnp.newaxis, :, :]
    distances = diff[jnp.triu_indices(n_particles, 1)]
    # Shift by L/2, wrap into [0, L), shift back: nearest periodic image.
    distances = jnp.remainder(distances + L / 2.0, L) - L / 2.0
    if norm:
        return jnp.linalg.norm(distances, axis=-1)
    return distances
def make_vec_periodic(vec, L):
    ''' encodes a vector with L-periodic trigonometric features
    Args:
    - vec: array whose last axis holds the components to encode
    - L: size of the system (the period)
    Returns:
    - array whose last axis is twice as long: sin(2*pi*vec/L) for every
      component, followed by cos(2*pi*vec/L) for every component
    '''
    phase = 2. * jnp.pi * vec / L
    sines = jnp.sin(phase)
    cosines = jnp.cos(phase)
    return jnp.concatenate([sines, cosines], axis=-1)
def distance_matrix(x, L, periodic = True):
    ''' computes all pairwise displacement vectors, optionally as periodic
    (sin/cos) features; no minimum-image wrapping is applied and the full
    (n, n) grid of pairs is returned, including the zero diagonal
    Args:
    - x: coords; the last axis holds the spatial components and the
      second-to-last axis indexes particles (any leading axes broadcast)
    - L: size of the system, used only when periodic is True
    - periodic: if True, encode each displacement via make_vec_periodic
      (sin and cos of 2*pi*r/L); otherwise return the raw displacements
    Returns:
    - array r with r[..., i, j, :] = x[..., i, :] - x[..., j, :]; its last
      axis has length sdim, or 2*sdim when periodic is True
    '''
    rij = x[..., :, jnp.newaxis, :] - x[..., jnp.newaxis, :, :]
    if periodic:
        return make_vec_periodic(rij, L)
    return rij