Skip to content

Commit bb237a3

Browse files
authored
Simplify: remove expect_2distr (#15)
@alleSini99 what do you think? Isn't this much simpler ? (We should check that it is as fast... but I think it is) cc @lgravina1997
1 parent a167b6d commit bb237a3

File tree

5 files changed

+44
-203
lines changed

5 files changed

+44
-203
lines changed

netket_fidelity/infidelity/overlap/expect.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from netket import jax as nkjax
88
from netket.utils import mpi
99

10-
from netket_fidelity.utils import expect_2distr
1110

1211
from .operator import InfidelityOperatorStandard
1312

@@ -76,7 +75,10 @@ def infidelity_sampling_MCState(
7675
σ_t = sigma_t.reshape(-1, N)
7776

7877
def expect_kernel(params):
79-
def kernel_fun(params, params_t, σ, σ_t):
78+
def kernel_fun(params_all, samples_all):
79+
params, params_t = params_all
80+
σ, σ_t = samples_all
81+
8082
W = {"params": params, **model_state}
8183
W_t = {"params": params_t, **model_state_t}
8284

@@ -91,14 +93,24 @@ def kernel_fun(params, params_t, σ, σ_t):
9193
lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real
9294
)
9395

94-
return expect_2distr(
95-
log_pdf,
96-
log_pdf_t,
96+
def log_pdf_joint(params_all, samples_all):
97+
params, params_t = params_all
98+
σ, σ_t = samples_all
99+
log_pdf_vals = log_pdf(params, σ)
100+
log_pdf_t_vals = log_pdf_t(params_t, σ_t)
101+
return log_pdf_vals + log_pdf_t_vals
102+
103+
return nkjax.expect(
104+
log_pdf_joint,
97105
kernel_fun,
98-
params,
99-
params_t,
100-
σ,
101-
σ_t,
106+
(
107+
params,
108+
params_t,
109+
),
110+
(
111+
σ,
112+
σ_t,
113+
),
102114
n_chains=n_chains_t,
103115
)
104116

netket_fidelity/infidelity/overlap_U/exact.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ def expect_fun(params):
8282
F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True)
8383

8484
F_grad = F_vjp_fun(jnp.ones_like(F))[0]
85-
F_grad = jax.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
86-
I_grad = jax.tree_map(lambda x: -x, F_grad)
85+
F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
86+
I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad)
8787
I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0)
8888

8989
return I_stats, I_grad

netket_fidelity/infidelity/overlap_U/expect.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from netket.vqs import MCState, expect, expect_and_grad, get_local_kernel_arguments
1010
from netket.utils import mpi
1111

12-
from netket_fidelity.utils import expect_2distr
1312

1413
from .operator import InfidelityOperatorUPsi
1514

@@ -113,7 +112,10 @@ def infidelity_sampling_MCState(
113112
xp_t_ravel = jnp.vstack(xp_t_splitted)
114113

115114
def expect_kernel(params):
116-
def kernel_fun(params, params_t, σ, σ_t):
115+
def kernel_fun(params_all, samples_all):
116+
params, params_t = params_all
117+
σ, σ_t = samples_all
118+
117119
W = {"params": params, **model_state}
118120
W_t = {"params": params_t, **model_state_t}
119121

@@ -139,14 +141,24 @@ def kernel_fun(params, params_t, σ, σ_t):
139141
lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real
140142
)
141143

142-
return expect_2distr(
143-
log_pdf,
144-
log_pdf_t,
144+
def log_pdf_joint(params_all, samples_all):
145+
params, params_t = params_all
146+
σ, σ_t = samples_all
147+
log_pdf_vals = log_pdf(params, σ)
148+
log_pdf_t_vals = log_pdf_t(params_t, σ_t)
149+
return log_pdf_vals + log_pdf_t_vals
150+
151+
return nkjax.expect(
152+
log_pdf_joint,
145153
kernel_fun,
146-
params,
147-
params_t,
148-
σ,
149-
σ_t,
154+
(
155+
params,
156+
params_t,
157+
),
158+
(
159+
σ,
160+
σ_t,
161+
),
150162
n_chains=n_chains_t,
151163
)
152164

netket_fidelity/utils/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from .expect import expect_2distr
21
from .sampling_Ustate import make_logpsi_U_afun, _logpsi_U_fun
32

43
from netket.utils import _hide_submodules

netket_fidelity/utils/expect.py

Lines changed: 0 additions & 182 deletions
This file was deleted.

0 commit comments

Comments
 (0)