Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 38 additions & 10 deletions meent/on_jax/emsolver/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ def solve_1d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):
else:
raise ValueError

# Track eps of layer below for same-material detection.
# Substrate eps uses conj(n)^2 to match layer convention.
eps_below = jnp.conj(self.n_bot) ** 2

# From the last layer
for layer_index in range(len(self.thickness))[::-1]:

Expand All @@ -310,12 +314,21 @@ def solve_1d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):
W, V, q = transfer_1d_2(self.pol, kx, epx_conv, epy_conv, epz_conv_i, self.type_complex,
self.perturbation, use_pinv=self.use_pinv)

# Detect if this layer has the same material as the layer below.
# For the bottom layer, compare with substrate eps.
# A uniform layer of the same material has diagonal eps_conv ≈ eps_below * I.
layer_eps_diag = jnp.diag(epx_conv).mean()
is_same = jnp.allclose(layer_eps_diag, eps_below, rtol=1e-3)

X, F, G, T, A_i, B = transfer_1d_3(k0, W, V, q, d, F, G, T, type_complex=self.type_complex,
use_pinv=self.use_pinv)
use_pinv=self.use_pinv, same_material=is_same)

layer_info = [epz_conv_i, W, V, q, d, A_i, B]
self.layer_info_list.append(layer_info)

# Update eps_below for next layer comparison
eps_below = layer_eps_diag

elif self.connecting_algo == 'SMM':
raise ValueError
# A, B, S_dict, Sg = scattering_1d_2(W, Wg, V, Vg, d, k0, Q, Sg)
Expand Down Expand Up @@ -362,6 +375,10 @@ def solve_1d_conical(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_al
else:
raise ValueError

# Track eps of layer below for same-material detection.
# Substrate eps uses conj(n)^2 to match layer convention.
eps_below = jnp.conj(self.n_bot) ** 2

for layer_index in range(len(self.thickness))[::-1]:

epx_conv = epx_conv_all[layer_index]
Expand All @@ -371,30 +388,30 @@ def solve_1d_conical(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_al
d = self.thickness[layer_index]

if self.connecting_algo == 'TMM':
# big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \
# = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d,
# varphi, big_F, big_G, big_T,
# type_complex=self.type_complex, device=self.device)
W, V, q = transfer_1d_conical_2(kx, ky, epx_conv, epy_conv, epz_conv_i, type_complex=self.type_complex,
perturbation=self.perturbation, device=self.device,
use_pinv=self.use_pinv)

# Detect if this layer has the same material as the layer below.
layer_eps_diag = jnp.diag(epx_conv).mean()
is_same = jnp.allclose(layer_eps_diag, eps_below, rtol=1e-3)

big_X, big_F, big_G, big_T, big_A_i, big_B, \
= transfer_1d_conical_3(k0, W, V, q, d, varphi, big_F, big_G, big_T, type_complex=self.type_complex,
use_pinv=self.use_pinv)
use_pinv=self.use_pinv, same_material=is_same)

layer_info = [epz_conv_i, W, V, q, d, big_A_i, big_B]
self.layer_info_list.append(layer_info)

# Update eps_below for next layer comparison
eps_below = layer_eps_diag

elif self.connecting_algo == 'SMM':
raise ValueError
else:
raise ValueError

if self.connecting_algo == 'TMM':
# de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff,
# delta_i0, k_I_z, k0, self.n_top, self.n_bot, k_II_z,
# type_complex=self.type_complex)
result, big_T1 = transfer_1d_conical_4(ff_x, ff_y, big_F, big_G, big_T, kz_top, kz_bot, self.psi,
self.theta, self.n_top, self.n_bot, type_complex=self.type_complex,
use_pinv=self.use_pinv)
Expand Down Expand Up @@ -431,6 +448,10 @@ def solve_2d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):
else:
raise ValueError

# Track eps of layer below for same-material detection.
# Substrate eps uses conj(n)^2 to match layer convention.
eps_below = jnp.conj(self.n_bot) ** 2

# From the last layer
for layer_index in range(len(self.thickness))[::-1]:

Expand All @@ -444,13 +465,20 @@ def solve_2d(self, wavelength, epx_conv_all, epy_conv_all, epz_conv_i_all):
W, V, q = transfer_2d_2(kx, ky, epx_conv, epy_conv, epz_conv_i, self.type_complex, self.perturbation,
use_pinv=self.use_pinv)

# Detect if this layer has the same material as the layer below.
layer_eps_diag = jnp.diag(epx_conv).mean()
is_same = jnp.allclose(layer_eps_diag, eps_below, rtol=1e-3)

big_X, big_F, big_G, big_T, big_A_i, big_B, \
= transfer_2d_3(k0, W, V, q, d, varphi, big_F, big_G, big_T, type_complex=self.type_complex,
use_pinv=self.use_pinv)
use_pinv=self.use_pinv, same_material=is_same)

layer_info = [epz_conv_i, W, V, q, d, big_A_i, big_B]
self.layer_info_list.append(layer_info)

# Update eps_below for next layer comparison
eps_below = layer_eps_diag

elif self.connecting_algo == 'SMM':
raise ValueError
# W, V, q = scattering_2d_wv(ff_xy, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i)
Expand Down
6 changes: 3 additions & 3 deletions meent/on_jax/emsolver/convolution_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def to_conv_mat_vector(ucell_info_list, fto_x, fto_y, device=None, type_complex=

for i, ucell_info in enumerate(ucell_info_list):
ucell_layer, x_list, y_list = ucell_info
eps_matrix = ucell_layer ** 2
eps_matrix = jnp.conj(ucell_layer) ** 2

epz_conv = cfs2d(eps_matrix, x_list, y_list, 1, 1, fto_x, fto_y, type_complex)
epy_conv = cfs2d(eps_matrix, x_list, y_list, 1, 0, fto_x, fto_y, type_complex)
Expand All @@ -81,7 +81,7 @@ def to_conv_mat_raster_continuous(ucell, fto_x, fto_y, device=None, type_complex

for i, layer in enumerate(ucell):
n_compressed, x_list, y_list = cell_compression(layer, type_complex=type_complex)
eps_matrix = n_compressed ** 2
eps_matrix = jnp.conj(n_compressed) ** 2

epz_conv = cfs2d(eps_matrix, x_list, y_list, 1, 1, fto_x, fto_y, type_complex)
epy_conv = cfs2d(eps_matrix, x_list, y_list, 1, 0, fto_x, fto_y, type_complex)
Expand Down Expand Up @@ -120,7 +120,7 @@ def to_conv_mat_raster_discrete(ucell, fto_x, fto_y, device=None, type_complex=j
n = minimum_pattern_size_x // layer.shape[1]
layer = jnp.repeat(layer, n + 1, axis=1, total_repeat_length=layer.shape[1] * (n + 1))

eps_matrix = layer ** 2
eps_matrix = jnp.conj(layer) ** 2

epz_conv = dfs2d(eps_matrix, 1, 1, fto_x, fto_y, type_complex)
epy_conv = dfs2d(eps_matrix, 1, 0, fto_x, fto_y, type_complex)
Expand Down
Loading