
Commit e8f9a21

Changed NN for dep coef to Transformer
1 parent 48a1110 commit e8f9a21

7 files changed (+207 −34 lines)

docs/preparation/ambiguities.rst

Lines changed: 2 additions & 2 deletions
@@ -113,12 +113,12 @@ Under this change, we have that

 .. math:: \cos 2\Phi_B' = \cos 2\Phi_B, \quad \sin 2\Phi_B' = \sin 2\Phi_B, \quad \cos \Phi_B' = \cos \Phi_B, \quad \sin \Phi_B' = \sin \Phi_B.

-Making use of the previous relations between the angles wrt to the
+Making use of the previous relations between the angles wrt to the
 vertical and the LOS, we have to solve the following equation:

 .. math:: \left( 3 \cos^2\theta_B'-1 \right) \sin^2 \Theta_B' = \left( 3 \cos^2\theta_B-1 \right) \sin^2 \Theta_B,

-which can be written as:
+which can be written as:

 .. math::

(The removed and added lines are textually identical, so this change is most likely a trailing-whitespace fix.)

examples/nonmpi/syn/caii_syn.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@


 # Test a single inversion in non-iterator mode
-mod = hazel.Model('../../configurations/conf_caii.ini', working_mode='synthesis', verbose=3, root='../../')
+mod = hazel.Model('../../configurations/conf_caii.ini', working_mode='synthesis', verbose=4, root='../../')
 mod.set_nlte(False)
 mod.synthesize()

@@ -29,4 +29,4 @@

 ax[1].legend()

-pl.show()
+pl.show()

(verbose is raised from 3 to 4, the level at which hazel/photosphere.py logs its ' - NLTE neural oracle' message; the two pl.show() lines are textually identical, so that change is most likely a newline added at the end of the file.)

hazel/__init__.py

Lines changed: 9 additions & 7 deletions
@@ -13,10 +13,12 @@
 from .util import *
 from . import codes

-try:
-    # import torch
-    # import torch_geometric
-    from .graphnet import *
-    from .forward_nn import *
-except:
-    pass
+# try:
+#     # import torch
+#     # import torch_geometric
+#     from .graphnet import *
+#     from .forward_nn import *
+# except:
+#     pass
+
+from .forward_nn_transformer import *
Binary file (72.5 MB) not shown (presumably the new NLTE Transformer checkpoint loaded in hazel/photosphere.py below).

hazel/forward_nn_transformer.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
import numpy as np
import glob
import torch
import torch.nn as nn
from sklearn import neighbors
import logging


class PositionalEncoding(nn.Module):
    def __init__(self, d_emb, norm=10000.0):
        """
        Inputs
            d_emb - Hidden dimensionality of the embedding.
        """
        super().__init__()
        self.d_emb = d_emb
        self.norm = norm

    def forward(self, t):
        pe = torch.zeros(t.shape[0], t.shape[1], self.d_emb).to(t.device)    # (B, T, D)
        div_term = torch.exp(torch.arange(0, self.d_emb, 2).float() * (-np.log(self.norm) / self.d_emb))[None, None, :].to(t.device)    # (1, 1, D/2)
        t = t.unsqueeze(2)    # (B, T, 1)
        pe[:, :, 0::2] = torch.sin(t * div_term)    # (B, T, D/2)
        pe[:, :, 1::2] = torch.cos(t * div_term)    # (B, T, D/2)
        return pe    # (B, T, D)

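The class above is the standard sinusoidal positional encoding, evaluated at a continuous (scaled) optical-depth coordinate instead of at integer token positions. A quick shape check (a sketch, not part of the commit; the sizes are arbitrary):

    # encode a dummy scaled log-tau axis of 70 depth points for a single atmosphere
    pe = PositionalEncoding(d_emb=128, norm=1000.0)
    t = torch.linspace(0.0, 100.0, 70)[None, :]    # (B=1, T=70)
    emb = pe(t)                                    # (1, 70, 128)
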
class TransformerModel(nn.Module):

    def __init__(self, ninp, nemb, nout, nhead, nhid, nlayers, dropout=0.1, norm=1000.0):
        """
        Transformer model for sequence-to-sequence learning

        Args:
            ninp (int): input size
            nemb (int): embedding size
            nout (int): output size
            nhead (int): number of attention heads
            nhid (int): hidden layer size in the feed-forward network
            nlayers (int): number of layers
            dropout (float, optional): dropout probability. Defaults to 0.1.
            norm (float, optional): normalization constant of the positional encoding. Defaults to 1000.0.
        """
        super(TransformerModel, self).__init__()

        self.model_type = 'Transformer'

        self.encoder = nn.Linear(ninp, nemb)

        self.pos_encoder = PositionalEncoding(nemb, norm)

        encoder_layers = nn.TransformerEncoderLayer(nemb, nhead, nhid, dropout, norm_first=True, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers, enable_nested_tensor=False)

        self.nemb = nemb
        self.decoder = nn.Linear(nemb, nout)

        self.init_weights()

    def init_weights(self):

        # Since TransformerEncoder inputs a TransformerEncoderLayer, all layers will use exactly the same initialization
        # We undo this here
        for name, param in self.named_parameters():
            if 'weight' in name and param.data.dim() == 2:
                nn.init.kaiming_uniform_(param)

    def forward(self, src, tau, src_mask):

        # Get tau embedding
        tau_emb = self.pos_encoder(tau)

        # Embed the input sequence into the embedding space and add the tau embedding
        x = self.encoder(src) + tau_emb

        # Apply the transformer encoder
        x = self.transformer_encoder(x, src_key_padding_mask=src_mask)

        # Apply the decoder to the output space
        x = self.decoder(x)

        # Zero out the outputs at padded positions
        output = (~src_mask).float()[:, :, None] * x

        return output

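A shape check of the model (a sketch, not part of the commit; the sizes are arbitrary, with nemb divisible by nhead):

    model = TransformerModel(ninp=4, nemb=128, nout=6, nhead=8, nhid=256, nlayers=4)
    src = torch.randn(1, 70, 4)                    # (B, T, ninp): one atmosphere, 70 depth points
    tau = torch.linspace(0.0, 100.0, 70)[None, :]  # (B, T): scaled optical-depth axis
    mask = torch.zeros(1, 70, dtype=torch.bool)    # no padded depth points
    out = model(src, tau, mask)                    # (1, 70, 6)
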
class Forward(object):
    def __init__(self, gpu=0, checkpoint=None, readir=None, verbose=0):

        self.logger = logging.getLogger("neural")
        self.logger.setLevel(logging.DEBUG)
        self.logger.handlers = []
        ch = logging.StreamHandler()
        formatter = logging.Formatter('%(asctime)s - %(message)s')
        ch.setFormatter(formatter)
        self.logger.addHandler(ch)

        # Is a GPU available?
        self.cuda = torch.cuda.is_available()
        self.gpu = gpu
        self.device = torch.device("cpu")    # GPU selection disabled: f"cuda:{self.gpu}" if self.cuda else "cpu"

        if (checkpoint is None):
            if readir is None:
                raise ValueError('No checkpoint or read directory selected')
            # Use the latest checkpoint found in the read directory
            files = glob.glob(readir + '*.pth')
            self.checkpoint = sorted(files)[-1]
        else:
            self.checkpoint = checkpoint

        checkpoint = torch.load(self.checkpoint, map_location=lambda storage, loc: storage, weights_only=False)

        self.hyperparameters = checkpoint['hyperparameters']
        self.predict_model = TransformerModel(ninp=self.hyperparameters['transformer']['n_input'],
                                              nemb=self.hyperparameters['transformer']['n_embedding'],
                                              nout=self.hyperparameters['transformer']['n_output'],
                                              nhead=self.hyperparameters['transformer']['n_heads'],
                                              nhid=self.hyperparameters['transformer']['n_hidden'],
                                              nlayers=self.hyperparameters['transformer']['n_layers'],
                                              norm=self.hyperparameters['transformer']['norm'],
                                              dropout=self.hyperparameters['transformer']['dropout']).to(self.device)
        self.predict_model.load_state_dict(checkpoint['state_dict'])

        self.predict_model.eval()

        if (verbose >= 1):
            npars = sum(p.numel() for p in self.predict_model.parameters() if p.requires_grad)
            tmp = self.checkpoint.split('/')
            self.logger.info(f' * Using neural checkpoint {tmp[-1]} on {self.device} - N. parameters = {npars}')

    def predict(self, tau_all, ne_all, vturb_all, T_all, vlos_all):

        # Normalize the inputs to the ranges used during training
        tau = (np.log10(tau_all.astype('float32')) + 10.0) * 10.0
        vturb = vturb_all.astype('float32') / 1e3 - 6.0
        vlos = vlos_all.astype('float32') / 1e3
        T = np.log10(T_all.astype('float32')) - 3.8
        ne = np.log10(ne_all.astype('float32')) - 16.0

        pars = np.concatenate([vturb[None, :], vlos[None, :], T[None, :], ne[None, :]], axis=0).T
        mask = np.zeros(len(tau)).astype('bool')

        pars = torch.tensor(pars, dtype=torch.float32).to(self.device)
        tau = torch.tensor(tau, dtype=torch.float32).to(self.device)
        mask = torch.tensor(mask, dtype=torch.bool).to(self.device)

        # Add a batch dimension and evaluate the model
        with torch.no_grad():
            self.pred_out = self.predict_model(pars[None, ...], tau[None, ...], mask[None, ...])

        return self.pred_out[0, ...].cpu().numpy()
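
A minimal usage sketch of Forward (not part of the commit; the checkpoint path is the one hazel/photosphere.py constructs, and the array units follow the comments there):

    import numpy as np

    nz = 70
    tau = np.logspace(-7.0, 1.0, nz)                    # optical depth
    T = 4000.0 + 3000.0 * np.linspace(0.0, 1.0, nz)     # temperature [K]
    ne = 1e16 * np.ones(nz)                             # electron density [m^-3]
    vturb = 2e3 * np.ones(nz)                           # microturbulence [m/s]
    vlos = np.zeros(nz)                                 # line-of-sight velocity [m/s]

    nn_oracle = Forward(checkpoint='hazel/data/2024-09-13-11_17_34.best.pth', verbose=1)
    pred = nn_oracle.predict(tau, ne, vturb, T, vlos)   # (nz, n_output)

    # photosphere.py reads the log10 departure coefficients of Ca II 8542
    # from columns 2 and 4 of the prediction
    b_low = 10.0**pred[:, 2]
    b_up = 10.0**pred[:, 4]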

hazel/model.py

Lines changed: 8 additions & 1 deletion
@@ -637,7 +637,7 @@ def add_photosphere(self, atmosphere):
         self.atmospheres[atm['name']].add_active_line(lines=lines, spectrum=self.spectrum[atm['spectral region']],
                                                       wvl_range=np.array(wvl_range), verbose=self.verbose)

-        if (self.atmospheres[atm['name']].graphnet_nlte is not None):
+        if (self.atmospheres[atm['name']].transformer_nlte is not None):
             self.set_nlte(True)

         if ('ranges' in atm):

@@ -1334,6 +1334,13 @@ def set_nlte(self, option):
         self.use_nlte = option
         if (self.verbose >= 1):
             self.logger.info('Setting NLTE for Ca II 8542 A to {0}'.format(self.use_nlte))
+        if (self.use_nlte):
+            for atmospheres in self.order_atmospheres:
+                for n, order in enumerate(atmospheres):
+                    for k, atm in enumerate(order):
+                        if (self.atmospheres[atm].type == 'photosphere'):
+                            self.atmospheres[atm].load_nlte_model(verbose=self.verbose)
+

     def synthesize(self, perturbation=False):
         """

hazel/photosphere.py

Lines changed: 39 additions & 22 deletions
@@ -9,10 +9,15 @@
 from hazel.exceptions import NumericalErrorSIR
 from hazel.transforms import transformed_to_physical, jacobian_transformation

-try:
-    from hazel.forward_nn import Forward
-except:
-    pass
+# try:
+#     from hazel.forward_nn import Forward
+# except:
+#     pass
+
+# try:
+from hazel.forward_nn_transformer import Forward
+# except:
+#     pass


 __all__ = ['SIR_atmosphere']

@@ -27,7 +32,7 @@ def __init__(self, working_mode, name='', root='', verbose=0):
         self.ff = 1.0
         self.macroturbulence = np.zeros(1)
         self.working_mode = working_mode
-        self.graphnet_nlte = None
+        self.transformer_nlte = None
         self.root = root

         self.parameters['T'] = None

@@ -175,15 +180,27 @@ def add_active_line(self, lines, spectrum, wvl_range, verbose):
         self.wvl_axis = spectrum.wavelength_axis[ind_low:ind_top+1]
         self.wvl_range = np.array([ind_low, ind_top+1])

-        # Check if Ca II 8542 is in the list of lines and instantiate the neural networks
+        # Check if Ca II 8542 is in the list of lines and instantiate the neural networks
         if (self.nlte):
             if 301 in self.lines:
-                if self.graphnet_nlte is None:
-                    path = str(__file__).split('/')
-                    checkpoint = '/'.join(path[0:-1])+'/data/20211114-131045_best.prd.pth'
-                    if (verbose >= 1):
-                        self.logger.info(' * Reading NLTE Neural Network')
-                    self.graphnet_nlte = Forward(checkpoint=checkpoint, verbose=verbose)
+                self.load_nlte_model(verbose=verbose)
+
+    def load_nlte_model(self, verbose):
+        if self.transformer_nlte is None:
+            # path = str(__file__).split('/')
+            # checkpoint = '/'.join(path[0:-1])+'/data/20211114-131045_best.prd.pth'
+            # if (verbose >= 1):
+            #     self.logger.info(' * Reading NLTE Neural Network')
+            # self.graphnet_nlte = Forward(checkpoint=checkpoint, verbose=verbose)
+
+            path = str(__file__).split('/')
+            checkpoint = '/'.join(path[0:-1])+'/data/2024-09-13-11_17_34.best.pth'
+            if (verbose >= 1):
+                self.logger.info(' * Reading NLTE Transformer Neural Network')
+            self.transformer_nlte = Forward(checkpoint=checkpoint, verbose=verbose)
+
+        self.nlte = True
+

     def interpolate_nodes(self, log_tau, reference, nodes, nodes_location):
         """

@@ -662,10 +679,10 @@ def synthesize(self, stokes_in, returnRF=False, nlte=False):
             self.Pe = sir_code.hydroeq(self.log_tau, self.parameters['T'],
                                        self.Pe, 1e5*self.parameters['vmic'], 1e5*self.parameters['v'], self.parameters['Bx'], self.parameters['By'],
                                        self.parameters['Bz'])
-
+
         # Check if the line is 8542 and we want NLTE. If that is the case, then evaluate the
        # neural network to return the departure coefficients
-        if (nlte):
+        if (nlte):
             if (self.nlte):
                 dif = (self.parameters['T'] - self.t_old)
                 if (np.max(dif) > self.t_change_departure):

@@ -674,15 +691,15 @@ def synthesize(self, stokes_in, returnRF=False, nlte=False):
                     if (self.verbose >= 4):
                         self.logger.info(' - NLTE neural oracle')
                     n = len(self.log_tau)
-                    tau = [10.0**self.log_tau[::-1]]
+                    tau = 10.0**self.log_tau[::-1]
                     ne = self.Pe / (1.381e-16 * self.parameters['T'])
-                    ne = [ne[::-1] * 1e6]    # in m^-3
-                    tt = [self.parameters['T'][::-1]]
-                    vturb = [self.parameters['vmic'][::-1] * 1e3]    # in m/s
-                    vlos = [self.parameters['v'][::-1] * 1e3]    # in m/s
-                    prediction = self.graphnet_nlte.predict(tau, ne, vturb, tt, vlos)
-                    self.departure[0, i, :] = 10.0**prediction[0][::-1, 2]
-                    self.departure[1, i, :] = 10.0**prediction[0][::-1, 4]
+                    ne = ne[::-1] * 1e6    # in m^-3
+                    tt = self.parameters['T'][::-1]
+                    vturb = self.parameters['vmic'][::-1] * 1e3    # in m/s
+                    vlos = self.parameters['v'][::-1] * 1e3    # in m/s
+                    prediction = self.transformer_nlte.predict(tau, ne, vturb, tt, vlos)
+                    self.departure[0, i, :] = 10.0**prediction[::-1, 2]
+                    self.departure[1, i, :] = 10.0**prediction[::-1, 4]

                 self.t_old = self.parameters['T']
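
Two details of the hunks above are easy to miss. First, predict() is now called with bare 1-D arrays instead of the single-element lists the previous graph network expected, which is why the result is indexed as prediction[::-1, k] rather than prediction[0][::-1, k]. Second, the network is only re-evaluated when the temperature has changed enough since the last call; a minimal standalone sketch of that check (threshold name from the diff, value hypothetical):

    import numpy as np

    t_change_departure = 50.0          # K (hypothetical value)
    t_old = 5000.0 * np.ones(70)
    T = t_old + 80.0                   # new temperature stratification

    if np.max(T - t_old) > t_change_departure:
        print('re-running the NLTE Transformer oracle')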
