import torchode as to # Lienen, M., & Günnemann, S. 2022, in The Symbiosis of Deep Learning and Differential Equations II, NeurIPS. https://openreview.net/forum?id=uiKVKTiUYB0
import src.mace.autoencoder as ae
import src.mace.latentODE as lODE
- from scipy.stats import gmean
from time import time
@@ -199,87 +198,90 @@ def forward(self, n_0, p, tstep):
## ---------- OLD VERSION OF THE SOLVER CLASS ---------- ##
## This class is compatible with an older version of the autoencoder

- class Solver_old(nn.Module):
-     '''
-     The Solver class presents the architecture of MACE.
-     Components:
-         1) Encoder; neural network with adjustable amount of nodes and layers
-         2) Neural ODE; ODE given by function g, with trainable elements
-         3) Decoder; neural network with adjustable amount of nodes and layers
-
-     '''
-     def __init__(self, p_dim, z_dim, DEVICE, n_dim=466, g_nn=False, atol=1e-5, rtol=1e-2):
-         super(Solver_old, self).__init__() # type: ignore
-
-         self.status_train = list()
-         self.status_test = list()
-
-         self.z_dim = z_dim
-         self.n_dim = n_dim
-         self.DEVICE = DEVICE
-         self.g_nn = g_nn
-
-         ## Setting the neural ODE
-         input_ae_dim = n_dim
-         if not self.g_nn:
-             self.g = lODE.G(z_dim)
-             input_ae_dim = input_ae_dim + p_dim
-             self.odeterm = to.ODETerm(self.g, with_args=False)
-         if self.g_nn:
-             self.g = lODE.Gnn(p_dim, z_dim)
-             self.odeterm = to.ODETerm(self.g, with_args=True)
-
-         self.step_method = to.Dopri5(term=self.odeterm)
-         self.step_size_controller = to.IntegralController(atol=atol, rtol=rtol, term=self.odeterm)
-         self.adjoint = to.AutoDiffAdjoint(self.step_method, self.step_size_controller).to(self.DEVICE) # type: ignore
-
-         self.jit_solver = torch.compile(self.adjoint)
-
-         ## Setting the autoencoder (encoder + decoder)
-         hidden_ae_dim = int(gmean([input_ae_dim, z_dim]))
-         self.encoder = ae.Encoder_old(input_dim=input_ae_dim, hidden_dim=hidden_ae_dim, latent_dim=z_dim)
-         self.decoder = ae.Decoder_old(latent_dim=z_dim, hidden_dim=hidden_ae_dim, output_dim=n_dim)

-     def set_status(self, status, phase):
-         if phase == 'train':
-             self.status_train.append(status)
-         elif phase == 'test':
-             self.status_test.append(status)
-
-     def get_status(self, phase):
-         if phase == 'train':
-             return np.array(self.status_train)
-         elif phase == 'test':
-             return np.array(self.status_test)
-
-
-     def forward(self, n_0, p, tstep):
-         '''
-         Forward function giving the workflow of the MACE architecture.
-         '''
-
-         x_0 = n_0          ## use NN version of G
-         if not self.g_nn:  ## DON'T use NN version of G
-             ## Ravel the abundances n_0 and physical input p to x_0
-             x_0 = torch.cat((p, n_0), axis=-1) # type: ignore
-
-         ## Encode x_0, returning the encoded z_0 in latent space
-         z_0 = self.encoder(x_0)
+ # from scipy.stats import gmean
+
+ # class Solver_old(nn.Module):
+ #     '''
+ #     The Solver class presents the architecture of MACE.
+ #     Components:
+ #         1) Encoder; neural network with adjustable amount of nodes and layers
+ #         2) Neural ODE; ODE given by function g, with trainable elements
+ #         3) Decoder; neural network with adjustable amount of nodes and layers
+
+ #     '''
+ #     def __init__(self, p_dim, z_dim, DEVICE, n_dim=466, g_nn=False, atol=1e-5, rtol=1e-2):
+ #         super(Solver_old, self).__init__() # type: ignore
+
+ #         self.status_train = list()
+ #         self.status_test = list()
+
+ #         self.z_dim = z_dim
+ #         self.n_dim = n_dim
+ #         self.DEVICE = DEVICE
+ #         self.g_nn = g_nn
+
+ #         ## Setting the neural ODE
+ #         input_ae_dim = n_dim
+ #         if not self.g_nn:
+ #             self.g = lODE.G(z_dim)
+ #             input_ae_dim = input_ae_dim + p_dim
+ #             self.odeterm = to.ODETerm(self.g, with_args=False)
+ #         if self.g_nn:
+ #             self.g = lODE.Gnn(p_dim, z_dim)
+ #             self.odeterm = to.ODETerm(self.g, with_args=True)
+
+ #         self.step_method = to.Dopri5(term=self.odeterm)
+ #         self.step_size_controller = to.IntegralController(atol=atol, rtol=rtol, term=self.odeterm)
+ #         self.adjoint = to.AutoDiffAdjoint(self.step_method, self.step_size_controller).to(self.DEVICE) # type: ignore
+
+ #         self.jit_solver = torch.compile(self.adjoint)
+
+ #         ## Setting the autoencoder (encoder + decoder)
+ #         hidden_ae_dim = int(gmean([input_ae_dim, z_dim]))
+ #         self.encoder = ae.Encoder_old(input_dim=input_ae_dim, hidden_dim=hidden_ae_dim, latent_dim=z_dim)
+ #         self.decoder = ae.Decoder_old(latent_dim=z_dim, hidden_dim=hidden_ae_dim, output_dim=n_dim)
+
+ #     def set_status(self, status, phase):
+ #         if phase == 'train':
+ #             self.status_train.append(status)
+ #         elif phase == 'test':
+ #             self.status_test.append(status)
+
+ #     def get_status(self, phase):
+ #         if phase == 'train':
+ #             return np.array(self.status_train)
+ #         elif phase == 'test':
+ #             return np.array(self.status_test)
+
+
+ #     def forward(self, n_0, p, tstep):
+ #         '''
+ #         Forward function giving the workflow of the MACE architecture.
+ #         '''
+
+ #         x_0 = n_0          ## use NN version of G
+ #         if not self.g_nn:  ## DON'T use NN version of G
+ #             ## Ravel the abundances n_0 and physical input p to x_0
+ #             x_0 = torch.cat((p, n_0), axis=-1) # type: ignore
+
+ #         ## Encode x_0, returning the encoded z_0 in latent space
+ #         z_0 = self.encoder(x_0)

-         ## Create initial value problem
-         problem = to.InitialValueProblem(
-             y0     = z_0.to(self.DEVICE),  ## "view" is to handle the batches
-             t_eval = tstep.view(z_0.shape[0], -1).to(self.DEVICE),
-         )
+ #         ## Create initial value problem
+ #         problem = to.InitialValueProblem(
+ #             y0     = z_0.to(self.DEVICE),  ## "view" is to handle the batches
+ #             t_eval = tstep.view(z_0.shape[0], -1).to(self.DEVICE),
+ #         )

-         ## Solve initial value problem. Details are set in the __init__() of this class.
-         solution = self.jit_solver.solve(problem, args=p)
-         z_s = solution.ys.view(-1, self.z_dim)  ## because of the batches
+ #         ## Solve initial value problem. Details are set in the __init__() of this class.
+ #         solution = self.jit_solver.solve(problem, args=p)
+ #         z_s = solution.ys.view(-1, self.z_dim)  ## because of the batches

-         ## Decode the resulting values from latent space z_s back to physical space
-         n_s_ravel = self.decoder(z_s)
+ #         ## Decode the resulting values from latent space z_s back to physical space
+ #         n_s_ravel = self.decoder(z_s)

-         ## Reshape correctly
-         n_s = n_s_ravel.reshape(1, tstep.shape[-1], self.n_dim)
+ #         ## Reshape correctly
+ #         n_s = n_s_ravel.reshape(1, tstep.shape[-1], self.n_dim)

-         return n_s, z_s, solution.status
+ #         return n_s, z_s, solution.status