Skip to content

Commit

Permalink
Netket fidelity Ju
Browse files Browse the repository at this point in the history
  • Loading branch information
alleSini99 committed Oct 17, 2024
1 parent 6295df9 commit 8d3ca2d
Show file tree
Hide file tree
Showing 10 changed files with 847 additions and 810 deletions.
16 changes: 8 additions & 8 deletions netket_fidelity/infidelity/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .logic import InfidelityOperator

from .overlap import InfidelityOperatorStandard, InfidelityUPsi
from .overlap_U import InfidelityOperatorUPsi

from netket.utils import _hide_submodules

_hide_submodules(__name__, hide_folder=["overlap", "overlap_U"])
from .logic import InfidelityOperator

from .overlap import InfidelityOperatorStandard, InfidelityUPsi
from .overlap_U import InfidelityOperatorUPsi

from netket.utils import _hide_submodules

_hide_submodules(__name__, hide_folder=["overlap", "overlap_U"])
376 changes: 189 additions & 187 deletions netket_fidelity/infidelity/logic.py

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions netket_fidelity/infidelity/overlap/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .operator import InfidelityOperatorStandard, InfidelityUPsi

from . import expect
from . import exact
from .operator import InfidelityOperatorStandard, InfidelityUPsi

from . import expect
from . import exact
156 changes: 78 additions & 78 deletions netket_fidelity/infidelity/overlap/exact.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,78 @@
from functools import partial

import jax.numpy as jnp
import jax

from netket import jax as nkjax
from netket.vqs import FullSumState, expect, expect_and_grad
from netket.utils import mpi
from netket.stats import Stats

from .operator import InfidelityOperatorStandard


@expect.dispatch
def infidelity(vstate: FullSumState, op: InfidelityOperatorStandard):
if op.hilbert != vstate.hilbert:
raise TypeError("Hilbert spaces should match")
if not isinstance(op.target, FullSumState):
raise TypeError("Can only compute infidelity of exact states.")

return infidelity_sampling_FullSumState(
vstate._apply_fun,
vstate.parameters,
vstate.model_state,
vstate._all_states,
op.target.to_array(),
return_grad=False,
)


@expect_and_grad.dispatch
def infidelity( # noqa: F811
vstate: FullSumState,
op: InfidelityOperatorStandard,
*,
mutable,
):
if op.hilbert != vstate.hilbert:
raise TypeError("Hilbert spaces should match")
if not isinstance(op.target, FullSumState):
raise TypeError("Can only compute infidelity of exact states.")

return infidelity_sampling_FullSumState(
vstate._apply_fun,
vstate.parameters,
vstate.model_state,
vstate._all_states,
op.target.to_array(),
return_grad=True,
)


@partial(jax.jit, static_argnames=("afun", "return_grad"))
def infidelity_sampling_FullSumState(
afun,
params,
model_state,
sigma,
state_t,
return_grad,
):
def expect_fun(params):
state = jnp.exp(afun({"params": params, **model_state}, sigma))
state = state / jnp.sqrt(jnp.sum(jnp.abs(state) ** 2))
return jnp.abs(state.T.conj() @ state_t) ** 2

if not return_grad:
F = expect_fun(params)
return Stats(mean=1 - F, error_of_mean=0.0, variance=0.0)

F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True)

F_grad = F_vjp_fun(jnp.ones_like(F))[0]
F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad)
I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0)

return I_stats, I_grad
from functools import partial

import jax.numpy as jnp
import jax

from netket import jax as nkjax
from netket.vqs import FullSumState, expect, expect_and_grad
from netket.utils import mpi
from netket.stats import Stats

from .operator import InfidelityOperatorStandard


@expect.dispatch
def infidelity(vstate: FullSumState, op: InfidelityOperatorStandard):
if op.hilbert != vstate.hilbert:
raise TypeError("Hilbert spaces should match")
if not isinstance(op.target, FullSumState):
raise TypeError("Can only compute infidelity of exact states.")

return infidelity_sampling_FullSumState(
vstate._apply_fun,
vstate.parameters,
vstate.model_state,
vstate._all_states,
op.target.to_array(),
return_grad=False,
)


@expect_and_grad.dispatch
def infidelity( # noqa: F811
vstate: FullSumState,
op: InfidelityOperatorStandard,
*,
mutable,
):
if op.hilbert != vstate.hilbert:
raise TypeError("Hilbert spaces should match")
if not isinstance(op.target, FullSumState):
raise TypeError("Can only compute infidelity of exact states.")

return infidelity_sampling_FullSumState(
vstate._apply_fun,
vstate.parameters,
vstate.model_state,
vstate._all_states,
op.target.to_array(),
return_grad=True,
)


@partial(jax.jit, static_argnames=("afun", "return_grad"))
def infidelity_sampling_FullSumState(
afun,
params,
model_state,
sigma,
state_t,
return_grad,
):
def expect_fun(params):
state = jnp.exp(afun({"params": params, **model_state}, sigma))
state = state / jnp.sqrt(jnp.sum(jnp.abs(state) ** 2))
return jnp.abs(state.T.conj() @ state_t) ** 2

if not return_grad:
F = expect_fun(params)
return Stats(mean=1 - F, error_of_mean=0.0, variance=0.0)

F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True)

F_grad = F_vjp_fun(jnp.ones_like(F))[0]
F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad)
I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad)
I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0)

return I_stats, I_grad
Loading

0 comments on commit 8d3ca2d

Please sign in to comment.