MPNN_model.py
import jax
import flax.linen as nn
import jax.numpy as jnp
from typing import Tuple, Callable, Optional
import numpy as np
from netket.utils.types import NNInitFunc, DType
from jax.nn.initializers import (
    zeros,
    ones,
    lecun_normal,
    normal,
)
from distances import distance_matrix, dist_min_image, make_vec_periodic
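# NOTE: the helpers imported from the local `distances` module are assumed (based on how
# they are used below) to return pairwise displacement vectors of shape
# (n_samples, N, N, sdim) from `distance_matrix`, and to wrap positions into the periodic
# box in `make_vec_periodic`; `dist_min_image` is imported but not used in this file.
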
class Phi(nn.Module):
    """
    Message-passing layer: a single feed-forward neural network (MLP).
    """
    output_dim: int
    widths: Tuple  # e.g. (16,)
    hidden_lyrs: int  # e.g. 1
    initializer: NNInitFunc = lecun_normal()
    activation: Callable = nn.activation.gelu
    out_lyr_activation: Optional[Callable] = None

    @nn.compact
    def __call__(self, x):
        # x: (n_samples, N, d)
        # Apply the hidden layers.
        for i in range(self.hidden_lyrs):
            x = nn.Dense(features=self.widths[i], kernel_init=self.initializer, param_dtype=np.float64)(x)
            x = self.activation(x)
            x = nn.LayerNorm(param_dtype=np.float64, use_bias=False, use_scale=False)(x)
        # Apply the last layer with the required output dimension.
        x = nn.Dense(features=self.output_dim, kernel_init=self.initializer, param_dtype=np.float64)(x)
        # Apply the output activation only if one was provided.
        if self.out_lyr_activation is not None:
            x = self.out_lyr_activation(x)
        return x
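
# A minimal usage sketch for Phi (the shapes and hyperparameters below are illustrative
# assumptions, not values used elsewhere in this repository):
#
#     phi = Phi(output_dim=8, widths=(16,), hidden_lyrs=1)
#     params = phi.init(jax.random.PRNGKey(0), jnp.ones((4, 3, 2)))
#     y = phi.apply(params, jnp.ones((4, 3, 2)))   # y.shape == (4, 3, 8)
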
class MPNN(nn.Module):
    '''Class for coordinate transformations with a Message-Passing Neural Network.

    Attributes:
        - L: Length of each dimension of the system
        - graph_number: Number of graph transformations to apply to the coordinates
        - phi_out_dim: Output dimension of each MLP
        - initializer: Initialization function
        - activation: Activation function for each MLP
        - phi_hidden_lyrs: Number of hidden layers for each MLP
        - phi_widths: Widths of the corresponding hidden layers
    '''
    L: Tuple
    graph_number: int
    phi_out_dim: int
    initializer: NNInitFunc = lecun_normal()
    activation: Callable = nn.activation.gelu
    phi_hidden_lyrs: int = 1
    phi_widths: Tuple = (5,)

    @nn.compact
    def __call__(self, ri):
        assert len(ri.shape) == 3
        N_samples, N, sdim = ri.shape
        L = jnp.array(self.L)
        # Create learnable hidden states for nodes and edges, tiled over the batch.
        hi = self.param("hidden_state_nodes", self.initializer, (1, 1, self.phi_widths[0]), np.float64)
        hij = self.param("hidden_state_edges", self.initializer, (1, 1, 1, self.phi_widths[0]), np.float64)
        hi = jnp.tile(hi, (N_samples, N, 1))
        hij = jnp.tile(hij, (N_samples, N, N, 1))
        # Raw (non-periodic) distance vectors between particles.
        dist = distance_matrix(ri, L, periodic=False)
        # Periodic distance vectors between particles.
        rij = distance_matrix(ri, L, periodic=True)
        # Map the position vectors into the periodic box.
        ri = make_vec_periodic(ri, L)
        # Squared norm of the sine-transformed distance vectors; the identity is added so the
        # diagonal entries are not zero vectors inside the norm, then the diagonal is masked back to zero.
        normij = jnp.linalg.norm(jnp.sin(jnp.pi * dist[..., :] / L) + jnp.eye(N)[..., None], axis=-1, keepdims=True)**2 * (
            1. - jnp.eye(N)[..., None])
        xi = jnp.concatenate((ri, hi), axis=-1)
        xij = jnp.concatenate((rij, normij, hij), axis=-1)
        for i in range(self.graph_number):
            phi = Phi(output_dim=self.phi_out_dim, widths=self.phi_widths, hidden_lyrs=self.phi_hidden_lyrs)
            f = Phi(output_dim=self.phi_out_dim, widths=self.phi_widths, hidden_lyrs=self.phi_hidden_lyrs)
            g = Phi(output_dim=self.phi_out_dim, widths=self.phi_widths, hidden_lyrs=self.phi_hidden_lyrs)
            nuij = phi(xij)
            if i != self.graph_number - 1:
                xij = jnp.concatenate((rij, normij, f(jnp.concatenate((xij, nuij), axis=-1))), axis=-1)
            # Update node features from the aggregated (summed) messages.
            xi = g(jnp.concatenate((xi, jnp.sum(nuij, axis=-2)), axis=-1))
        return xi
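
# A minimal usage sketch for MPNN (illustrative assumptions: a 2D box of side 1.0 with 3
# particles and 4 samples; running it requires the local `distances` module):
#
#     mpnn = MPNN(L=(1.0, 1.0), graph_number=2, phi_out_dim=5)
#     ri = jax.random.uniform(jax.random.PRNGKey(0), (4, 3, 2))
#     params = mpnn.init(jax.random.PRNGKey(1), ri)
#     xi = mpnn.apply(params, ri)   # per-node features of shape (4, 3, phi_out_dim)
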
class logpsi(nn.Module):
    r"""
    Brings together the MPNN and a simple feed-forward NN to model \ln(\psi).
    """
    L: Tuple
    sdim: int
    graph_number: int
    phi_out_dim: int
    initializer: NNInitFunc = lecun_normal()
    activation: Callable = nn.activation.gelu
    phi_hidden_lyrs: int = 1
    phi_widths: Tuple = (5,)
    rho_hidden_lyrs: int = 1
    rho_widths: Tuple = (5,)

    @nn.compact
    def __call__(self, x):
        N = x.shape[-1] // self.sdim
        x = x.reshape(-1, N, self.sdim)
        mpnn = MPNN(self.L, self.graph_number, self.phi_out_dim, self.initializer, self.activation, self.phi_hidden_lyrs, self.phi_widths)
        # Transform the coordinates x into MPNN features.
        x = mpnn(x)
        for i in range(self.rho_hidden_lyrs):
            x = nn.Dense(features=self.rho_widths[i], kernel_init=self.initializer, param_dtype=np.float64)(x)
            x = self.activation(x)
            x = nn.LayerNorm(param_dtype=np.float64, use_bias=False, use_scale=False)(x)
        x = nn.Dense(features=1, kernel_init=self.initializer, param_dtype=np.float64)(x)
        # Sum over particles so the output is permutation invariant: one value per sample.
        x = jnp.sum(x, axis=-2)
        return x.reshape(-1)
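

if __name__ == "__main__":
    # Quick smoke test, a minimal sketch only: the box size, particle number and batch
    # size below are illustrative assumptions, and running this requires the local
    # `distances` module to be importable.
    key = jax.random.PRNGKey(0)
    n_samples, n_particles, sdim = 4, 3, 2
    model = logpsi(L=(1.0, 1.0), sdim=sdim, graph_number=2, phi_out_dim=5)
    x = jax.random.uniform(key, (n_samples, n_particles * sdim))
    params = model.init(key, x)
    log_psi = model.apply(params, x)
    print("log_psi shape:", log_psi.shape)   # expected: (n_samples,)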