-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6295df9
commit 8d3ca2d
Showing
10 changed files
with
847 additions
and
810 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,8 @@ | ||
from .logic import InfidelityOperator | ||
|
||
from .overlap import InfidelityOperatorStandard, InfidelityUPsi | ||
from .overlap_U import InfidelityOperatorUPsi | ||
|
||
from netket.utils import _hide_submodules | ||
|
||
_hide_submodules(__name__, hide_folder=["overlap", "overlap_U"]) | ||
from .logic import InfidelityOperator | ||
|
||
from .overlap import InfidelityOperatorStandard, InfidelityUPsi | ||
from .overlap_U import InfidelityOperatorUPsi | ||
|
||
from netket.utils import _hide_submodules | ||
|
||
_hide_submodules(__name__, hide_folder=["overlap", "overlap_U"]) |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from .operator import InfidelityOperatorStandard, InfidelityUPsi | ||
|
||
from . import expect | ||
from . import exact | ||
from .operator import InfidelityOperatorStandard, InfidelityUPsi | ||
|
||
from . import expect | ||
from . import exact |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,78 +1,78 @@ | ||
from functools import partial | ||
|
||
import jax.numpy as jnp | ||
import jax | ||
|
||
from netket import jax as nkjax | ||
from netket.vqs import FullSumState, expect, expect_and_grad | ||
from netket.utils import mpi | ||
from netket.stats import Stats | ||
|
||
from .operator import InfidelityOperatorStandard | ||
|
||
|
||
@expect.dispatch | ||
def infidelity(vstate: FullSumState, op: InfidelityOperatorStandard): | ||
if op.hilbert != vstate.hilbert: | ||
raise TypeError("Hilbert spaces should match") | ||
if not isinstance(op.target, FullSumState): | ||
raise TypeError("Can only compute infidelity of exact states.") | ||
|
||
return infidelity_sampling_FullSumState( | ||
vstate._apply_fun, | ||
vstate.parameters, | ||
vstate.model_state, | ||
vstate._all_states, | ||
op.target.to_array(), | ||
return_grad=False, | ||
) | ||
|
||
|
||
@expect_and_grad.dispatch | ||
def infidelity( # noqa: F811 | ||
vstate: FullSumState, | ||
op: InfidelityOperatorStandard, | ||
*, | ||
mutable, | ||
): | ||
if op.hilbert != vstate.hilbert: | ||
raise TypeError("Hilbert spaces should match") | ||
if not isinstance(op.target, FullSumState): | ||
raise TypeError("Can only compute infidelity of exact states.") | ||
|
||
return infidelity_sampling_FullSumState( | ||
vstate._apply_fun, | ||
vstate.parameters, | ||
vstate.model_state, | ||
vstate._all_states, | ||
op.target.to_array(), | ||
return_grad=True, | ||
) | ||
|
||
|
||
@partial(jax.jit, static_argnames=("afun", "return_grad")) | ||
def infidelity_sampling_FullSumState( | ||
afun, | ||
params, | ||
model_state, | ||
sigma, | ||
state_t, | ||
return_grad, | ||
): | ||
def expect_fun(params): | ||
state = jnp.exp(afun({"params": params, **model_state}, sigma)) | ||
state = state / jnp.sqrt(jnp.sum(jnp.abs(state) ** 2)) | ||
return jnp.abs(state.T.conj() @ state_t) ** 2 | ||
|
||
if not return_grad: | ||
F = expect_fun(params) | ||
return Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) | ||
|
||
F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True) | ||
|
||
F_grad = F_vjp_fun(jnp.ones_like(F))[0] | ||
F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) | ||
I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) | ||
I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) | ||
|
||
return I_stats, I_grad | ||
from functools import partial | ||
|
||
import jax.numpy as jnp | ||
import jax | ||
|
||
from netket import jax as nkjax | ||
from netket.vqs import FullSumState, expect, expect_and_grad | ||
from netket.utils import mpi | ||
from netket.stats import Stats | ||
|
||
from .operator import InfidelityOperatorStandard | ||
|
||
|
||
@expect.dispatch | ||
def infidelity(vstate: FullSumState, op: InfidelityOperatorStandard): | ||
if op.hilbert != vstate.hilbert: | ||
raise TypeError("Hilbert spaces should match") | ||
if not isinstance(op.target, FullSumState): | ||
raise TypeError("Can only compute infidelity of exact states.") | ||
|
||
return infidelity_sampling_FullSumState( | ||
vstate._apply_fun, | ||
vstate.parameters, | ||
vstate.model_state, | ||
vstate._all_states, | ||
op.target.to_array(), | ||
return_grad=False, | ||
) | ||
|
||
|
||
@expect_and_grad.dispatch | ||
def infidelity( # noqa: F811 | ||
vstate: FullSumState, | ||
op: InfidelityOperatorStandard, | ||
*, | ||
mutable, | ||
): | ||
if op.hilbert != vstate.hilbert: | ||
raise TypeError("Hilbert spaces should match") | ||
if not isinstance(op.target, FullSumState): | ||
raise TypeError("Can only compute infidelity of exact states.") | ||
|
||
return infidelity_sampling_FullSumState( | ||
vstate._apply_fun, | ||
vstate.parameters, | ||
vstate.model_state, | ||
vstate._all_states, | ||
op.target.to_array(), | ||
return_grad=True, | ||
) | ||
|
||
|
||
@partial(jax.jit, static_argnames=("afun", "return_grad")) | ||
def infidelity_sampling_FullSumState( | ||
afun, | ||
params, | ||
model_state, | ||
sigma, | ||
state_t, | ||
return_grad, | ||
): | ||
def expect_fun(params): | ||
state = jnp.exp(afun({"params": params, **model_state}, sigma)) | ||
state = state / jnp.sqrt(jnp.sum(jnp.abs(state) ** 2)) | ||
return jnp.abs(state.T.conj() @ state_t) ** 2 | ||
|
||
if not return_grad: | ||
F = expect_fun(params) | ||
return Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) | ||
|
||
F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True) | ||
|
||
F_grad = F_vjp_fun(jnp.ones_like(F))[0] | ||
F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) | ||
I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) | ||
I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) | ||
|
||
return I_stats, I_grad |
Oops, something went wrong.