Skip to content

Commit 53d37a9

Browse files
committed
Commenting out old version of the autoencoder
1 parent 980762a commit 53d37a9

File tree

1 file changed

+81
-79
lines changed

1 file changed

+81
-79
lines changed

src/mace/mace.py

Lines changed: 81 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import torchode as to # Lienen, M., & Günnemann, S. 2022, in The Symbiosis of Deep Learning and Differential Equations II, NeurIPS. https://openreview.net/forum?id=uiKVKTiUYB0
1414
import src.mace.autoencoder as ae
1515
import src.mace.latentODE as lODE
16-
from scipy.stats import gmean
1716
from time import time
1817

1918

@@ -199,87 +198,90 @@ def forward(self, n_0, p, tstep):
199198
## ---------- OLD VERSION OF THE SOLVER CLASS ---------- ##
200199
## This class is compatible with an older version of the autoencoder
201200

202-
class Solver_old(nn.Module):
203-
'''
204-
The Solver class presents the architecture of MACE.
205-
Components:
206-
1) Encoder; neural network with adjustable amount of nodes and layers
207-
2) Neural ODE; ODE given by function g, with trainable elements
208-
3) Decoder; neural network with adjustable amount of nodes and layers
209-
210-
'''
211-
def __init__(self, p_dim, z_dim, DEVICE, n_dim=466, g_nn = False, atol = 1e-5, rtol = 1e-2):
212-
super(Solver_old, self).__init__() # type: ignore
213-
214-
self.status_train = list()
215-
self.status_test = list()
216-
217-
self.z_dim = z_dim
218-
self.n_dim = n_dim
219-
self.DEVICE = DEVICE
220-
self.g_nn = g_nn
221-
222-
## Setting the neural ODE
223-
input_ae_dim = n_dim
224-
if not self.g_nn:
225-
self.g = lODE.G(z_dim)
226-
input_ae_dim = input_ae_dim+p_dim
227-
self.odeterm = to.ODETerm(self.g, with_args=False)
228-
if self.g_nn:
229-
self.g = lODE.Gnn(p_dim, z_dim)
230-
self.odeterm = to.ODETerm(self.g, with_args=True)
231-
232-
self.step_method = to.Dopri5(term=self.odeterm)
233-
self.step_size_controller = to.IntegralController(atol=atol, rtol=rtol, term=self.odeterm)
234-
self.adjoint = to.AutoDiffAdjoint(self.step_method, self.step_size_controller).to(self.DEVICE) # type: ignore
235-
236-
self.jit_solver = torch.compile(self.adjoint)
237-
238-
## Setting the autoencoder (enocder + decoder)
239-
hidden_ae_dim = int(gmean([input_ae_dim, z_dim]))
240-
self.encoder = ae.Encoder_old(input_dim=input_ae_dim, hidden_dim=hidden_ae_dim, latent_dim=z_dim)
241-
self.decoder = ae.Decoder_old(latent_dim=z_dim , hidden_dim=hidden_ae_dim, output_dim=n_dim)
242201

243-
def set_status(self, status, phase):
244-
if phase == 'train':
245-
self.status_train.append(status)
246-
elif phase == 'test':
247-
self.status_test.append(status)
248-
249-
def get_status(self, phase):
250-
if phase == 'train':
251-
return np.array(self.status_train)
252-
elif phase == 'test':
253-
return np.array(self.status_test)
254-
255-
256-
def forward(self, n_0, p, tstep):
257-
'''
258-
Forward function giving the workflow of the MACE architecture.
259-
'''
260-
261-
x_0 = n_0 ## use NN version of G
262-
if not self.g_nn: ## DON'T use NN version of G
263-
## Ravel the abundances n_0 and physical input p to x_0
264-
x_0 = torch.cat((p, n_0), axis=-1) # type: ignore
265-
266-
## Encode x_0, returning the encoded z_0 in latent space
267-
z_0 = self.encoder(x_0)
202+
# from scipy.stats import gmean
203+
204+
# class Solver_old(nn.Module):
205+
# '''
206+
# The Solver class presents the architecture of MACE.
207+
# Components:
208+
# 1) Encoder; neural network with adjustable amount of nodes and layers
209+
# 2) Neural ODE; ODE given by function g, with trainable elements
210+
# 3) Decoder; neural network with adjustable amount of nodes and layers
211+
212+
# '''
213+
# def __init__(self, p_dim, z_dim, DEVICE, n_dim=466, g_nn = False, atol = 1e-5, rtol = 1e-2):
214+
# super(Solver_old, self).__init__() # type: ignore
215+
216+
# self.status_train = list()
217+
# self.status_test = list()
218+
219+
# self.z_dim = z_dim
220+
# self.n_dim = n_dim
221+
# self.DEVICE = DEVICE
222+
# self.g_nn = g_nn
223+
224+
# ## Setting the neural ODE
225+
# input_ae_dim = n_dim
226+
# if not self.g_nn:
227+
# self.g = lODE.G(z_dim)
228+
# input_ae_dim = input_ae_dim+p_dim
229+
# self.odeterm = to.ODETerm(self.g, with_args=False)
230+
# if self.g_nn:
231+
# self.g = lODE.Gnn(p_dim, z_dim)
232+
# self.odeterm = to.ODETerm(self.g, with_args=True)
233+
234+
# self.step_method = to.Dopri5(term=self.odeterm)
235+
# self.step_size_controller = to.IntegralController(atol=atol, rtol=rtol, term=self.odeterm)
236+
# self.adjoint = to.AutoDiffAdjoint(self.step_method, self.step_size_controller).to(self.DEVICE) # type: ignore
237+
238+
# self.jit_solver = torch.compile(self.adjoint)
239+
240+
# ## Setting the autoencoder (enocder + decoder)
241+
# hidden_ae_dim = int(gmean([input_ae_dim, z_dim]))
242+
# self.encoder = ae.Encoder_old(input_dim=input_ae_dim, hidden_dim=hidden_ae_dim, latent_dim=z_dim)
243+
# self.decoder = ae.Decoder_old(latent_dim=z_dim , hidden_dim=hidden_ae_dim, output_dim=n_dim)
244+
245+
# def set_status(self, status, phase):
246+
# if phase == 'train':
247+
# self.status_train.append(status)
248+
# elif phase == 'test':
249+
# self.status_test.append(status)
250+
251+
# def get_status(self, phase):
252+
# if phase == 'train':
253+
# return np.array(self.status_train)
254+
# elif phase == 'test':
255+
# return np.array(self.status_test)
256+
257+
258+
# def forward(self, n_0, p, tstep):
259+
# '''
260+
# Forward function giving the workflow of the MACE architecture.
261+
# '''
262+
263+
# x_0 = n_0 ## use NN version of G
264+
# if not self.g_nn: ## DON'T use NN version of G
265+
# ## Ravel the abundances n_0 and physical input p to x_0
266+
# x_0 = torch.cat((p, n_0), axis=-1) # type: ignore
267+
268+
# ## Encode x_0, returning the encoded z_0 in latent space
269+
# z_0 = self.encoder(x_0)
268270

269-
## Create initial value problem
270-
problem = to.InitialValueProblem(
271-
y0 = z_0.to(self.DEVICE), ## "view" is om met de batches om te gaan
272-
t_eval = tstep.view(z_0.shape[0],-1).to(self.DEVICE),
273-
)
271+
# ## Create initial value problem
272+
# problem = to.InitialValueProblem(
273+
# y0 = z_0.to(self.DEVICE), ## "view" is om met de batches om te gaan
274+
# t_eval = tstep.view(z_0.shape[0],-1).to(self.DEVICE),
275+
# )
274276

275-
## Solve initial value problem. Details are set in the __init__() of this class.
276-
solution = self.jit_solver.solve(problem, args=p)
277-
z_s = solution.ys.view(-1, self.z_dim) ## want batches
277+
# ## Solve initial value problem. Details are set in the __init__() of this class.
278+
# solution = self.jit_solver.solve(problem, args=p)
279+
# z_s = solution.ys.view(-1, self.z_dim) ## want batches
278280

279-
## Decode the resulting values from latent space z_s back to physical space
280-
n_s_ravel = self.decoder(z_s)
281+
# ## Decode the resulting values from latent space z_s back to physical space
282+
# n_s_ravel = self.decoder(z_s)
281283

282-
## Reshape correctly
283-
n_s = n_s_ravel.reshape(1,tstep.shape[-1], self.n_dim)
284+
# ## Reshape correctly
285+
# n_s = n_s_ravel.reshape(1,tstep.shape[-1], self.n_dim)
284286

285-
return n_s, z_s, solution.status
287+
# return n_s, z_s, solution.status

0 commit comments

Comments
 (0)