diff --git a/netket_fidelity/infidelity/__init__.py b/netket_fidelity/infidelity/__init__.py index 0e7e191..f761e3f 100644 --- a/netket_fidelity/infidelity/__init__.py +++ b/netket_fidelity/infidelity/__init__.py @@ -1,8 +1,8 @@ -from .logic import InfidelityOperator - -from .overlap import InfidelityOperatorStandard, InfidelityUPsi -from .overlap_U import InfidelityOperatorUPsi - -from netket.utils import _hide_submodules - -_hide_submodules(__name__, hide_folder=["overlap", "overlap_U"]) +from .logic import InfidelityOperator + +from .overlap import InfidelityOperatorStandard, InfidelityUPsi +from .overlap_U import InfidelityOperatorUPsi + +from netket.utils import _hide_submodules + +_hide_submodules(__name__, hide_folder=["overlap", "overlap_U"]) diff --git a/netket_fidelity/infidelity/logic.py b/netket_fidelity/infidelity/logic.py index aa0da51..73e76d9 100644 --- a/netket_fidelity/infidelity/logic.py +++ b/netket_fidelity/infidelity/logic.py @@ -1,187 +1,189 @@ -from typing import Optional - -from netket.operator import AbstractOperator, Adjoint -from netket.vqs import VariationalState, FullSumState -from netket.utils.types import DType - -from .overlap import InfidelityOperatorStandard, InfidelityUPsi -from .overlap_U import InfidelityOperatorUPsi - - -def InfidelityOperator( - target: VariationalState, - *, - U: AbstractOperator = None, - U_dagger: AbstractOperator = None, - cv_coeff: Optional[float] = None, - is_unitary: bool = False, - dtype: Optional[DType] = None, - sample_Upsi=False, -): - r""" - Operator I_op computing the infidelity I among two variational states - :math:`|\psi\rangle` and :math:`|\phi\rangle` as: - - .. math:: - - I = 1 - \frac{|⟨\Psi|\Phi⟩|^2 }{ ⟨\Psi|\Psi⟩ ⟨\Phi|\Phi⟩ } = 1 - \frac{⟨\Psi|\hat{I}_{op}|\Psi⟩ }{ ⟨\Psi|\Psi⟩ } - - where: - - .. math:: - - I_{op} = \frac {|\Phi\rangle\langle\Phi| }{ \langle\Phi|\Phi\rangle } - - The state :math:`|\phi\rangle` can be an autonomous state :math:`|\Phi\rangle = |\phi\rangle` - or an operator :math:`U` applied to it, namely - :math:`|\Phi\rangle = U|\phi\rangle`. :math:`I_{op}` is defined by the - state :math:`|\phi\rangle` (called target) and, possibly, by the operator - :math:`U`. If :math:`U` is not specified, it is assumed :math:`|\Phi\rangle = |\phi\rangle`. - - The Monte Carlo estimator of I is: - - .. math:: - - I = \mathbb{E}_{χ}[ I_{loc}(\sigma,\eta) ] = - \mathbb{E}_{χ}\left[\frac{⟨\sigma|\Phi⟩ ⟨\eta|\Psi⟩}{⟨σ|\Psi⟩ ⟨η|\Phi⟩}\right] - - where the sampled probability distribution :math:`χ` is defined as: - - .. math:: - - \chi(\sigma, \eta) = \frac{|\psi(\sigma)|^2 |\Phi(\eta)|^2}{ - \langle\Psi|\Psi\rangle \langle\Phi|\Phi\rangle}. - - In practice, since I is a real quantity, :math:`\rm{Re}[I_{loc}(\sigma,\eta)]` - is used. This estimator can be utilized both when :math:`|\Phi\rangle =|\phi\rangle` and - when :math:`|\Phi\rangle =U|\phi\rangle`, with :math:`U` a (unitary or non-unitary) operator. - In the second case, we have to sample from :math:`U|\phi\rangle` and this is implemented in - the function :class:`netket_fidelity.infidelity.InfidelityUPsi` . - - This works only with the operators provdided in the package. - We remark that sampling from :math:`U|\phi\rangle` requires to compute connected elements of - :math:`U` and so is more expensive than sampling from an autonomous state. - The choice of this estimator is specified by passing :code:`sample_Upsi=True`, - while the flag argument :code:`is_unitary` indicates whether :math:`U` is unitary or not. - - If :math:`U` is unitary, the following alternative estimator can be used: - - .. math:: - - I = \mathbb{E}_{χ'}\left[ I_{loc}(\sigma, \eta) \right] = - \mathbb{E}_{χ}\left[\frac{\langle\sigma|U|\phi\rangle \langle\eta|\psi\rangle}{ - \langle\sigma|U^{\dagger}|\psi\rangle ⟨\eta|\phi⟩} \right]. - - where the sampled probability distribution :math:`\chi` is defined as: - - .. math:: - - \chi'(\sigma, \eta) = \frac{|\psi(\sigma)|^2 |\phi(\eta)|^2}{ - \langle\Psi|\Psi\rangle \langle\phi|\phi\rangle}. - - This estimator is more efficient since it does not require to sample from - :math:`U|\phi\rangle`, but only from :math:`|\phi\rangle`. - This choice of the estimator is the default and it works only - with `is_unitary==True` (besides :code:`sample_Upsi=False` ). - When :math:`|\Phi⟩ = |\phi⟩` the two estimators coincides. - - To reduce the variance of the estimator, the Control Variates (CV) method can be applied. This consists - in modifying the estimator into: - - .. math:: - - I_{loc}^{CV} = \rm{Re}\left[I_{loc}(\sigma, \eta)\right] - c \left(|1 - I_{loc}(\sigma, \eta)^2| - 1\right) - - where :math:`c ∈ \mathbb{R}`. The constant c is chosen to minimize the variance of - :math:`I_{loc}^{CV}` as: - - .. math:: - - c* = \frac{\rm{Cov}_{χ}\left[ |1-I_{loc}|^2, \rm{Re}\left[1-I_{loc}\right]\right]}{ - \rm{Var}_{χ}\left[ |1-I_{loc}|^2\right] }, - - where :math:`\rm{Cov}\left\cdot, \cdot\right]` indicates the covariance and :math:`\rm{Var}\left[\cdot\right]` the variance. - In the relevant limit :math:`|\Psi⟩ \rightarrow|\Phi⟩`, we have :math:`c^\star \rightarrow -1/2`. The value :math:`-1/2` is - adopted as default value for c in the infidelity - estimator. To not apply CV, set c=0. - - Args: - target: target variational state :math:`|\phi⟩` . - U: operator :math:`\hat{U}`. - U_dagger: dagger operator :math:`\hat{U^\dagger}`. - cv_coeff: Control Variates coefficient c. - is_unitary: flag specifiying the unitarity of :math:`\hat{U}`. If True with - :code:`sample_Upsi=False`, the second estimator is used. - dtype: The dtype of the output of expectation value and gradient. - sample_Upsi: flag specifiying whether to sample from |ϕ⟩ or from U|ϕ⟩. If False with `is_unitary=False` , an error occurs. - - Returns: - Infidelity operator for which computing expected value and gradient. - - Examples: - - >>> import netket as nk - >>> import netket_fidelity as nkf - >>> - >>> hi = nk.hilbert.Spin(0.5, 4) - >>> sampler = nk.sampler.MetropolisLocal(hilbert=hi, n_chains_per_rank=16) - >>> model = nk.models.RBM(alpha=1, param_dtype=complex) - >>> target_vstate = nk.vqs.MCState(sampler=sampler, model=model, n_samples=100) - >>> - >>> # To optimise the overlap with |ϕ⟩ - >>> I_op = nkf.InfidelityOperator(target_vstate) - >>> - >>> # To optimise the overlap with U|ϕ⟩ by sampling from |ψ⟩ and |ϕ⟩ - >>> U = nkf.operator.Rx(0.3) - >>> I_op = nkf.InfidelityOperator(target_vstate, U=U, is_unitary=True) - >>> - >>> # To optimise the overlap with U|ϕ⟩ by sampling from |ψ⟩ and U|ϕ⟩ - >>> I_op = nkf.InfidelityOperator(target_vstate, U=U, sample_Upsi=True) - - """ - if U is None: - return InfidelityOperatorStandard(target, cv_coeff=cv_coeff, dtype=dtype) - else: - if U_dagger is None: - U_dagger = U.H - if isinstance(U_dagger, Adjoint): - raise TypeError( - "Must explicitly pass a jax-compatible operator as `U_dagger`. " - "You either did not pass `U_dagger` explicitly or you used `U.H` but should " - "use operators coming from `netket_fidelity`. " - ) - - if isinstance(target, FullSumState): - return InfidelityOperatorUPsi( - U, - target, - U_dagger=U_dagger, - cv_coeff=cv_coeff, - dtype=dtype, - ) - - if not is_unitary and not sample_Upsi: - raise ValueError( - "Non-unitary operators can only be handled by sampling from the state U|ψ⟩. " - "This is more expensive and disabled by default. " - "" - "If your operator is Unitary, please specify so by passing `is_unitary=True` as a " - "keyword argument. " - "" - "If your operator is not unitary, please specify `sample_Upsi=True` explicitly to" - "sample from that state. " - "You can also sample from U|ψ⟩ if your operator is unitary. " - "" - ) - - if sample_Upsi: - return InfidelityUPsi(U, target, cv_coeff=cv_coeff, dtype=dtype) - else: - return InfidelityOperatorUPsi( - U, - target, - U_dagger=U_dagger, - cv_coeff=cv_coeff, - dtype=dtype, - is_unitary=True, - ) +from typing import Optional + +from netket.operator import AbstractOperator, Adjoint +from netket.vqs import VariationalState, FullSumState +from netket.utils.types import DType + +from .overlap import InfidelityOperatorStandard, InfidelityUPsi +from .overlap_U import InfidelityOperatorUPsi + + +def InfidelityOperator( + target: VariationalState, + *, + U: AbstractOperator = None, + delta_tau: Optional[float] = None, + ham: Optional[AbstractOperator] = None, + U_dagger: AbstractOperator = None, + cv_coeff: Optional[float] = None, + is_unitary: bool = False, + dtype: Optional[DType] = None, + sample_Upsi=False, +): + r""" + Operator I_op computing the infidelity I among two variational states + :math:`|\psi\rangle` and :math:`|\phi\rangle` as: + + .. math:: + + I = 1 - \frac{|⟨\Psi|\Phi⟩|^2 }{ ⟨\Psi|\Psi⟩ ⟨\Phi|\Phi⟩ } = 1 - \frac{⟨\Psi|\hat{I}_{op}|\Psi⟩ }{ ⟨\Psi|\Psi⟩ } + + where: + + .. math:: + + I_{op} = \frac {|\Phi\rangle\langle\Phi| }{ \langle\Phi|\Phi\rangle } + + The state :math:`|\phi\rangle` can be an autonomous state :math:`|\Phi\rangle = |\phi\rangle` + or an operator :math:`U` applied to it, namely + :math:`|\Phi\rangle = U|\phi\rangle`. :math:`I_{op}` is defined by the + state :math:`|\phi\rangle` (called target) and, possibly, by the operator + :math:`U`. If :math:`U` is not specified, it is assumed :math:`|\Phi\rangle = |\phi\rangle`. + + The Monte Carlo estimator of I is: + + .. math:: + + I = \mathbb{E}_{χ}[ I_{loc}(\sigma,\eta) ] = + \mathbb{E}_{χ}\left[\frac{⟨\sigma|\Phi⟩ ⟨\eta|\Psi⟩}{⟨σ|\Psi⟩ ⟨η|\Phi⟩}\right] + + where the sampled probability distribution :math:`χ` is defined as: + + .. math:: + + \chi(\sigma, \eta) = \frac{|\psi(\sigma)|^2 |\Phi(\eta)|^2}{ + \langle\Psi|\Psi\rangle \langle\Phi|\Phi\rangle}. + + In practice, since I is a real quantity, :math:`\rm{Re}[I_{loc}(\sigma,\eta)]` + is used. This estimator can be utilized both when :math:`|\Phi\rangle =|\phi\rangle` and + when :math:`|\Phi\rangle =U|\phi\rangle`, with :math:`U` a (unitary or non-unitary) operator. + In the second case, we have to sample from :math:`U|\phi\rangle` and this is implemented in + the function :class:`netket_fidelity.infidelity.InfidelityUPsi` . + + This works only with the operators provdided in the package. + We remark that sampling from :math:`U|\phi\rangle` requires to compute connected elements of + :math:`U` and so is more expensive than sampling from an autonomous state. + The choice of this estimator is specified by passing :code:`sample_Upsi=True`, + while the flag argument :code:`is_unitary` indicates whether :math:`U` is unitary or not. + + If :math:`U` is unitary, the following alternative estimator can be used: + + .. math:: + + I = \mathbb{E}_{χ'}\left[ I_{loc}(\sigma, \eta) \right] = + \mathbb{E}_{χ}\left[\frac{\langle\sigma|U|\phi\rangle \langle\eta|\psi\rangle}{ + \langle\sigma|U^{\dagger}|\psi\rangle ⟨\eta|\phi⟩} \right]. + + where the sampled probability distribution :math:`\chi` is defined as: + + .. math:: + + \chi'(\sigma, \eta) = \frac{|\psi(\sigma)|^2 |\phi(\eta)|^2}{ + \langle\Psi|\Psi\rangle \langle\phi|\phi\rangle}. + + This estimator is more efficient since it does not require to sample from + :math:`U|\phi\rangle`, but only from :math:`|\phi\rangle`. + This choice of the estimator is the default and it works only + with `is_unitary==True` (besides :code:`sample_Upsi=False` ). + When :math:`|\Phi⟩ = |\phi⟩` the two estimators coincides. + + To reduce the variance of the estimator, the Control Variates (CV) method can be applied. This consists + in modifying the estimator into: + + .. math:: + + I_{loc}^{CV} = \rm{Re}\left[I_{loc}(\sigma, \eta)\right] - c \left(|1 - I_{loc}(\sigma, \eta)^2| - 1\right) + + where :math:`c ∈ \mathbb{R}`. The constant c is chosen to minimize the variance of + :math:`I_{loc}^{CV}` as: + + .. math:: + + c* = \frac{\rm{Cov}_{χ}\left[ |1-I_{loc}|^2, \rm{Re}\left[1-I_{loc}\right]\right]}{ + \rm{Var}_{χ}\left[ |1-I_{loc}|^2\right] }, + + where :math:`\rm{Cov}\left\cdot, \cdot\right]` indicates the covariance and :math:`\rm{Var}\left[\cdot\right]` the variance. + In the relevant limit :math:`|\Psi⟩ \rightarrow|\Phi⟩`, we have :math:`c^\star \rightarrow -1/2`. The value :math:`-1/2` is + adopted as default value for c in the infidelity + estimator. To not apply CV, set c=0. + + Args: + target: target variational state :math:`|\phi⟩` . + U: operator :math:`\hat{U}`. + U_dagger: dagger operator :math:`\hat{U^\dagger}`. + cv_coeff: Control Variates coefficient c. + is_unitary: flag specifiying the unitarity of :math:`\hat{U}`. If True with + :code:`sample_Upsi=False`, the second estimator is used. + dtype: The dtype of the output of expectation value and gradient. + sample_Upsi: flag specifiying whether to sample from |ϕ⟩ or from U|ϕ⟩. If False with `is_unitary=False` , an error occurs. + + Returns: + Infidelity operator for which computing expected value and gradient. + + Examples: + + >>> import netket as nk + >>> import netket_fidelity as nkf + >>> + >>> hi = nk.hilbert.Spin(0.5, 4) + >>> sampler = nk.sampler.MetropolisLocal(hilbert=hi, n_chains_per_rank=16) + >>> model = nk.models.RBM(alpha=1, param_dtype=complex) + >>> target_vstate = nk.vqs.MCState(sampler=sampler, model=model, n_samples=100) + >>> + >>> # To optimise the overlap with |ϕ⟩ + >>> I_op = nkf.InfidelityOperator(target_vstate) + >>> + >>> # To optimise the overlap with U|ϕ⟩ by sampling from |ψ⟩ and |ϕ⟩ + >>> U = nkf.operator.Rx(0.3) + >>> I_op = nkf.InfidelityOperator(target_vstate, U=U, is_unitary=True) + >>> + >>> # To optimise the overlap with U|ϕ⟩ by sampling from |ψ⟩ and U|ϕ⟩ + >>> I_op = nkf.InfidelityOperator(target_vstate, U=U, sample_Upsi=True) + + """ + if U is None: + return InfidelityOperatorStandard(target,delta_tau=delta_tau,ham=ham, cv_coeff=cv_coeff, dtype=dtype) + else: + if U_dagger is None: + U_dagger = U.H + if isinstance(U_dagger, Adjoint): + raise TypeError( + "Must explicitly pass a jax-compatible operator as `U_dagger`. " + "You either did not pass `U_dagger` explicitly or you used `U.H` but should " + "use operators coming from `netket_fidelity`. " + ) + + if isinstance(target, FullSumState): + return InfidelityOperatorUPsi( + U, + target, + U_dagger=U_dagger, + cv_coeff=cv_coeff, + dtype=dtype, + ) + + if not is_unitary and not sample_Upsi: + raise ValueError( + "Non-unitary operators can only be handled by sampling from the state U|ψ⟩. " + "This is more expensive and disabled by default. " + "" + "If your operator is Unitary, please specify so by passing `is_unitary=True` as a " + "keyword argument. " + "" + "If your operator is not unitary, please specify `sample_Upsi=True` explicitly to" + "sample from that state. " + "You can also sample from U|ψ⟩ if your operator is unitary. " + "" + ) + + if sample_Upsi: + return InfidelityUPsi(U, target, cv_coeff=cv_coeff, dtype=dtype) + else: + return InfidelityOperatorUPsi( + U, + target, + U_dagger=U_dagger, + cv_coeff=cv_coeff, + dtype=dtype, + is_unitary=True, + ) diff --git a/netket_fidelity/infidelity/overlap/__init__.py b/netket_fidelity/infidelity/overlap/__init__.py index bb3e351..e79800b 100644 --- a/netket_fidelity/infidelity/overlap/__init__.py +++ b/netket_fidelity/infidelity/overlap/__init__.py @@ -1,4 +1,4 @@ -from .operator import InfidelityOperatorStandard, InfidelityUPsi - -from . import expect -from . import exact +from .operator import InfidelityOperatorStandard, InfidelityUPsi + +from . import expect +from . import exact diff --git a/netket_fidelity/infidelity/overlap/exact.py b/netket_fidelity/infidelity/overlap/exact.py index aefe9bc..91c622a 100644 --- a/netket_fidelity/infidelity/overlap/exact.py +++ b/netket_fidelity/infidelity/overlap/exact.py @@ -1,78 +1,78 @@ -from functools import partial - -import jax.numpy as jnp -import jax - -from netket import jax as nkjax -from netket.vqs import FullSumState, expect, expect_and_grad -from netket.utils import mpi -from netket.stats import Stats - -from .operator import InfidelityOperatorStandard - - -@expect.dispatch -def infidelity(vstate: FullSumState, op: InfidelityOperatorStandard): - if op.hilbert != vstate.hilbert: - raise TypeError("Hilbert spaces should match") - if not isinstance(op.target, FullSumState): - raise TypeError("Can only compute infidelity of exact states.") - - return infidelity_sampling_FullSumState( - vstate._apply_fun, - vstate.parameters, - vstate.model_state, - vstate._all_states, - op.target.to_array(), - return_grad=False, - ) - - -@expect_and_grad.dispatch -def infidelity( # noqa: F811 - vstate: FullSumState, - op: InfidelityOperatorStandard, - *, - mutable, -): - if op.hilbert != vstate.hilbert: - raise TypeError("Hilbert spaces should match") - if not isinstance(op.target, FullSumState): - raise TypeError("Can only compute infidelity of exact states.") - - return infidelity_sampling_FullSumState( - vstate._apply_fun, - vstate.parameters, - vstate.model_state, - vstate._all_states, - op.target.to_array(), - return_grad=True, - ) - - -@partial(jax.jit, static_argnames=("afun", "return_grad")) -def infidelity_sampling_FullSumState( - afun, - params, - model_state, - sigma, - state_t, - return_grad, -): - def expect_fun(params): - state = jnp.exp(afun({"params": params, **model_state}, sigma)) - state = state / jnp.sqrt(jnp.sum(jnp.abs(state) ** 2)) - return jnp.abs(state.T.conj() @ state_t) ** 2 - - if not return_grad: - F = expect_fun(params) - return Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) - - F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True) - - F_grad = F_vjp_fun(jnp.ones_like(F))[0] - F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) - I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) - I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) - - return I_stats, I_grad +from functools import partial + +import jax.numpy as jnp +import jax + +from netket import jax as nkjax +from netket.vqs import FullSumState, expect, expect_and_grad +from netket.utils import mpi +from netket.stats import Stats + +from .operator import InfidelityOperatorStandard + + +@expect.dispatch +def infidelity(vstate: FullSumState, op: InfidelityOperatorStandard): + if op.hilbert != vstate.hilbert: + raise TypeError("Hilbert spaces should match") + if not isinstance(op.target, FullSumState): + raise TypeError("Can only compute infidelity of exact states.") + + return infidelity_sampling_FullSumState( + vstate._apply_fun, + vstate.parameters, + vstate.model_state, + vstate._all_states, + op.target.to_array(), + return_grad=False, + ) + + +@expect_and_grad.dispatch +def infidelity( # noqa: F811 + vstate: FullSumState, + op: InfidelityOperatorStandard, + *, + mutable, +): + if op.hilbert != vstate.hilbert: + raise TypeError("Hilbert spaces should match") + if not isinstance(op.target, FullSumState): + raise TypeError("Can only compute infidelity of exact states.") + + return infidelity_sampling_FullSumState( + vstate._apply_fun, + vstate.parameters, + vstate.model_state, + vstate._all_states, + op.target.to_array(), + return_grad=True, + ) + + +@partial(jax.jit, static_argnames=("afun", "return_grad")) +def infidelity_sampling_FullSumState( + afun, + params, + model_state, + sigma, + state_t, + return_grad, +): + def expect_fun(params): + state = jnp.exp(afun({"params": params, **model_state}, sigma)) + state = state / jnp.sqrt(jnp.sum(jnp.abs(state) ** 2)) + return jnp.abs(state.T.conj() @ state_t) ** 2 + + if not return_grad: + F = expect_fun(params) + return Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) + + F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True) + + F_grad = F_vjp_fun(jnp.ones_like(F))[0] + F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) + I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) + I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) + + return I_stats, I_grad diff --git a/netket_fidelity/infidelity/overlap/expect.py b/netket_fidelity/infidelity/overlap/expect.py index 8d342ce..62000ca 100644 --- a/netket_fidelity/infidelity/overlap/expect.py +++ b/netket_fidelity/infidelity/overlap/expect.py @@ -1,130 +1,153 @@ -from functools import partial - -import jax.numpy as jnp -import jax - -from netket.vqs import MCState, expect, expect_and_grad -from netket import jax as nkjax -from netket.utils import mpi - - -from .operator import InfidelityOperatorStandard - - -@expect.dispatch -def infidelity(vstate: MCState, op: InfidelityOperatorStandard, chunk_size: None): - if op.hilbert != vstate.hilbert: - raise TypeError("Hilbert spaces should match") - - return infidelity_sampling_MCState( - vstate._apply_fun, - op.target._apply_fun, - vstate.parameters, - op.target.parameters, - vstate.model_state, - op.target.model_state, - vstate.samples, - op.target.samples, - op.cv_coeff, - return_grad=False, - ) - - -@expect_and_grad.dispatch -def infidelity( # noqa: F811 - vstate: MCState, - op: InfidelityOperatorStandard, - chunk_size: None, - *, - mutable, -): - if op.hilbert != vstate.hilbert: - raise TypeError("Hilbert spaces should match") - - return infidelity_sampling_MCState( - vstate._apply_fun, - op.target._apply_fun, - vstate.parameters, - op.target.parameters, - vstate.model_state, - op.target.model_state, - vstate.samples, - op.target.samples, - op.cv_coeff, - return_grad=True, - ) - - -@partial(jax.jit, static_argnames=("afun", "afun_t", "return_grad")) -def infidelity_sampling_MCState( - afun, - afun_t, - params, - params_t, - model_state, - model_state_t, - sigma, - sigma_t, - cv_coeff, - return_grad, -): - N = sigma.shape[-1] - n_chains_t = sigma_t.shape[-2] - - σ = sigma.reshape(-1, N) - σ_t = sigma_t.reshape(-1, N) - - def expect_kernel(params): - def kernel_fun(params_all, samples_all): - params, params_t = params_all - σ, σ_t = samples_all - - W = {"params": params, **model_state} - W_t = {"params": params_t, **model_state_t} - - log_val = afun_t(W_t, σ) + afun(W, σ_t) - afun(W, σ) - afun_t(W_t, σ_t) - res = jnp.exp(log_val).real - if cv_coeff is not None: - res = res + cv_coeff * (jnp.exp(2 * log_val.real) - 1) - return res - - log_pdf = lambda params, σ: 2 * afun({"params": params, **model_state}, σ).real - log_pdf_t = ( - lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real - ) - - def log_pdf_joint(params_all, samples_all): - params, params_t = params_all - σ, σ_t = samples_all - log_pdf_vals = log_pdf(params, σ) - log_pdf_t_vals = log_pdf_t(params_t, σ_t) - return log_pdf_vals + log_pdf_t_vals - - return nkjax.expect( - log_pdf_joint, - kernel_fun, - ( - params, - params_t, - ), - ( - σ, - σ_t, - ), - n_chains=n_chains_t, - ) - - if not return_grad: - F, F_stats = expect_kernel(params) - return F_stats.replace(mean=1 - F) - - F, F_vjp_fun, F_stats = nkjax.vjp( - expect_kernel, params, has_aux=True, conjugate=True - ) - - F_grad = F_vjp_fun(jnp.ones_like(F))[0] - F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) - I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) - I_stats = F_stats.replace(mean=1 - F) - - return I_stats, I_grad +from functools import partial + +import jax.numpy as jnp +import jax +import netket as nk + +from netket.vqs import MCState, expect, expect_and_grad +from netket import jax as nkjax +from netket.utils import mpi + + +from .operator import InfidelityOperatorStandard + + +@expect.dispatch +def infidelity(vstate: MCState, op: InfidelityOperatorStandard, chunk_size: None): + if op.hilbert != vstate.hilbert: + raise TypeError("Hilbert spaces should match") + + b = op.target.expect_and_grad(op.ham)[1] + + return infidelity_sampling_MCState( + vstate._apply_fun, + op.target._apply_fun, + vstate.parameters, + op.target.parameters, + vstate.model_state, + op.target.model_state, + vstate.samples, + op.target.samples, + op.cv_coeff, + op.delta_tau, + b, + return_grad=False, + ) + + +@expect_and_grad.dispatch +def infidelity( # noqa: F811 + vstate: MCState, + op: InfidelityOperatorStandard, + chunk_size: None, + *, + mutable, +): + if op.hilbert != vstate.hilbert: + raise TypeError("Hilbert spaces should match") + + print(op.ham) + + + b = op.target.expect_and_grad(op.ham)[1] + return infidelity_sampling_MCState( + vstate._apply_fun, + op.target._apply_fun, + vstate.parameters, + op.target.parameters, + vstate.model_state, + op.target.model_state, + vstate.samples, + op.target.samples, + op.cv_coeff, + op.delta_tau, + b, + return_grad=True, + ) + + +@partial(jax.jit, static_argnames=("afun", "afun_t", "return_grad")) +def infidelity_sampling_MCState( + afun, + afun_t, + params, + params_t, + model_state, + model_state_t, + sigma, + sigma_t, + cv_coeff, + delta_tau, + b, + return_grad, +): + N = sigma.shape[-1] + n_chains_t = sigma_t.shape[-2] + + σ = sigma.reshape(-1, N) + σ_t = sigma_t.reshape(-1, N) + + def expect_kernel(params): + def kernel_fun(params_all, samples_all): + params, params_t = params_all + σ, σ_t = samples_all + + W = {"params": params, **model_state} + W_t = {"params": params_t, **model_state_t} + + log_val = afun_t(W_t, σ) + afun(W, σ_t) - afun(W, σ) - afun_t(W_t, σ_t) + res = jnp.exp(log_val).real + if cv_coeff is not None: + res = res + cv_coeff * (jnp.exp(2 * log_val.real) - 1) + return res + + log_pdf = lambda params, σ: 2 * afun({"params": params, **model_state}, σ).real + log_pdf_t = ( + lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real + ) + + def log_pdf_joint(params_all, samples_all): + params, params_t = params_all + σ, σ_t = samples_all + log_pdf_vals = log_pdf(params, σ) + log_pdf_t_vals = log_pdf_t(params_t, σ_t) + return log_pdf_vals + log_pdf_t_vals + + return nkjax.expect( + log_pdf_joint, + kernel_fun, + ( + params, + params_t, + ), + ( + σ, + σ_t, + ), + n_chains=n_chains_t, + ) + def Lsquared(params): + params_flat = nk.jax.tree_ravel(params)[0] + params_t_flat = nk.jax.tree_ravel(params_t)[0] + delta_params_flat = params_flat - params_t_flat + b_flat = nk.jax.tree_ravel(b)[0] + return jnp.absolute(1 - expect_kernel(params)[0] - delta_tau*delta_params_flat.conj() @ b_flat)**2 + + params_flat = nk.jax.tree_ravel(params)[0] + params_t_flat = nk.jax.tree_ravel(params_t)[0] + delta_params_flat = params_flat - params_t_flat + b_flat = nk.jax.tree_ravel(b)[0] + if not return_grad: + F, F_stats = expect_kernel(params) + return F_stats.replace(mean=jnp.absolute(1 - F - delta_tau*delta_params_flat.conj() @ b_flat)**2) + + F, F_vjp_fun = nkjax.vjp( + Lsquared, params, has_aux=False, conjugate=True + ) + + F_grad = F_vjp_fun(jnp.ones_like(F))[0] + F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) + #I_grad = jax.tree_util.tree_map(lambda x, y: -x-delta_tau*y, F_grad, b) + #I_stats = F_stats.replace(mean=jnp.absolute(1 - F - delta_tau*delta_params_flat.conj() @ b_flat)**2) + + return F, F_grad diff --git a/netket_fidelity/infidelity/overlap/operator.py b/netket_fidelity/infidelity/overlap/operator.py index 9835778..661ca89 100644 --- a/netket_fidelity/infidelity/overlap/operator.py +++ b/netket_fidelity/infidelity/overlap/operator.py @@ -1,82 +1,94 @@ -from typing import Optional - -import jax.numpy as jnp - -from netket.experimental.observable import AbstractObservable - -from netket.operator import AbstractOperator, DiscreteJaxOperator -from netket.utils.types import DType -from netket.utils.numbers import is_scalar -from netket.vqs import VariationalState, MCState, FullSumState - -from netket_fidelity.utils.sampling_Ustate import make_logpsi_U_afun - - -class InfidelityOperatorStandard(AbstractObservable): - def __init__( - self, - target: VariationalState, - *, - cv_coeff: Optional[float] = None, - dtype: Optional[DType] = None, - ): - super().__init__(target.hilbert) - - if not isinstance(target, VariationalState): - raise TypeError("The first argument should be a variational target.") - - if cv_coeff is not None: - cv_coeff = jnp.array(cv_coeff) - - if (not is_scalar(cv_coeff)) or jnp.iscomplex(cv_coeff): - raise TypeError("`cv_coeff` should be a real scalar number or None.") - - if isinstance(target, FullSumState): - cv_coeff = None - - self._target = target - self._cv_coeff = cv_coeff - self._dtype = dtype - - @property - def target(self): - return self._target - - @property - def cv_coeff(self): - return self._cv_coeff - - @property - def dtype(self): - return self._dtype - - @property - def is_hermitian(self): - return True - - def __repr__(self): - return f"InfidelityOperator(target={self.target}, cv_coeff={self.cv_coeff})" - - -def InfidelityUPsi( - U: AbstractOperator, - state: VariationalState, - *, - cv_coeff: Optional[float] = None, - dtype: Optional[DType] = None, -): - if not isinstance(U, DiscreteJaxOperator): - raise TypeError( - "In order to sample from the state U|psi>, U must be" - "an instance of DiscreteJaxOperator." - ) - - logpsiU, variables_U = make_logpsi_U_afun(state._apply_fun, U, state.variables) - target = MCState( - sampler=state.sampler, - apply_fun=logpsiU, - n_samples=state.n_samples, - variables=variables_U, - ) - - return InfidelityOperatorStandard(target, cv_coeff=cv_coeff, dtype=dtype) +from typing import Optional + +import jax.numpy as jnp + +from netket.experimental.observable import AbstractObservable + +from netket.operator import AbstractOperator, DiscreteJaxOperator +from netket.utils.types import DType +from netket.utils.numbers import is_scalar +from netket.vqs import VariationalState, MCState, FullSumState + +from netket_fidelity.utils.sampling_Ustate import make_logpsi_U_afun + + +class InfidelityOperatorStandard(AbstractObservable): + def __init__( + self, + target: VariationalState, + delta_tau: Optional[float] = None, + ham: Optional[AbstractOperator] = None, + *, + cv_coeff: Optional[float] = None, + dtype: Optional[DType] = None, + ): + super().__init__(target.hilbert) + + if not isinstance(target, VariationalState): + raise TypeError("The first argument should be a variational target.") + + if cv_coeff is not None: + cv_coeff = jnp.array(cv_coeff) + + if (not is_scalar(cv_coeff)) or jnp.iscomplex(cv_coeff): + raise TypeError("`cv_coeff` should be a real scalar number or None.") + + if isinstance(target, FullSumState): + cv_coeff = None + + self._target = target + self._cv_coeff = cv_coeff + self._dtype = dtype + self._delta_tau = delta_tau + self._ham = ham + + @property + def target(self): + return self._target + + @property + def cv_coeff(self): + return self._cv_coeff + + @property + def dtype(self): + return self._dtype + + @property + def is_hermitian(self): + return True + + @property + def delta_tau(self): + return self._delta_tau + + @property + def ham(self): + return self._ham + + def __repr__(self): + return f"InfidelityOperator(target={self.target}, cv_coeff={self.cv_coeff})" + + +def InfidelityUPsi( + U: AbstractOperator, + state: VariationalState, + *, + cv_coeff: Optional[float] = None, + dtype: Optional[DType] = None, +): + if not isinstance(U, DiscreteJaxOperator): + raise TypeError( + "In order to sample from the state U|psi>, U must be" + "an instance of DiscreteJaxOperator." + ) + + logpsiU, variables_U = make_logpsi_U_afun(state._apply_fun, U, state.variables) + target = MCState( + sampler=state.sampler, + apply_fun=logpsiU, + n_samples=state.n_samples, + variables=variables_U, + ) + + return InfidelityOperatorStandard(target, cv_coeff=cv_coeff, dtype=dtype) diff --git a/netket_fidelity/infidelity/overlap_U/__init__.py b/netket_fidelity/infidelity/overlap_U/__init__.py index 05576e6..94c1a7f 100644 --- a/netket_fidelity/infidelity/overlap_U/__init__.py +++ b/netket_fidelity/infidelity/overlap_U/__init__.py @@ -1,4 +1,4 @@ -from .operator import InfidelityOperatorUPsi - -from . import expect -from . import exact +from .operator import InfidelityOperatorUPsi + +from . import expect +from . import exact diff --git a/netket_fidelity/infidelity/overlap_U/exact.py b/netket_fidelity/infidelity/overlap_U/exact.py index 2819ea5..0b6c380 100644 --- a/netket_fidelity/infidelity/overlap_U/exact.py +++ b/netket_fidelity/infidelity/overlap_U/exact.py @@ -1,89 +1,89 @@ -import jax.numpy as jnp -import jax -from functools import partial - -from netket import jax as nkjax -from netket.vqs import FullSumState, expect, expect_and_grad -from netket.utils import mpi -from netket.stats import Stats - -from .operator import InfidelityOperatorUPsi - - -def sparsify(U): - return U.to_sparse() - - -@expect.dispatch -def infidelity(vstate: FullSumState, op: InfidelityOperatorUPsi): - if op.hilbert != vstate.hilbert: - raise TypeError("Hilbert spaces should match") - if not isinstance(op.target, FullSumState): - raise TypeError("Can only compute infidelity of exact states.") - - U_sp = sparsify(op._U) - Ustate_t = U_sp @ op.target.to_array(normalize=False) - - return infidelity_sampling_FullSumState( - vstate._apply_fun, - vstate.parameters, - vstate.model_state, - vstate._all_states, - Ustate_t, - return_grad=False, - ) - - -@expect_and_grad.dispatch -def infidelity( # noqa: F811 - vstate: FullSumState, - op: InfidelityOperatorUPsi, - *, - mutable, -): - if op.hilbert != vstate.hilbert: - raise TypeError("Hilbert spaces should match") - if not isinstance(op.target, FullSumState): - raise TypeError("Can only compute infidelity of exact states.") - - U_sp = sparsify(op._U) - Ustate_t = U_sp @ op.target.to_array(normalize=False) - - return infidelity_sampling_FullSumState( - vstate._apply_fun, - vstate.parameters, - vstate.model_state, - vstate._all_states, - Ustate_t, - return_grad=True, - ) - - -@partial(jax.jit, static_argnames=("afun", "return_grad")) -def infidelity_sampling_FullSumState( - afun, - params, - model_state, - sigma, - Ustate_t, - return_grad, -): - def expect_fun(params): - state = jnp.exp(afun({"params": params, **model_state}, sigma)) - state = state / jnp.sqrt(jnp.sum(jnp.abs(state) ** 2)) - return jnp.abs(state.T.conj().T @ Ustate_t) ** 2 / ( - Ustate_t.conj().T @ Ustate_t - ) - - if not return_grad: - F = expect_fun(params) - return Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) - - F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True) - - F_grad = F_vjp_fun(jnp.ones_like(F))[0] - F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) - I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) - I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) - - return I_stats, I_grad +import jax.numpy as jnp +import jax +from functools import partial + +from netket import jax as nkjax +from netket.vqs import FullSumState, expect, expect_and_grad +from netket.utils import mpi +from netket.stats import Stats + +from .operator import InfidelityOperatorUPsi + + +def sparsify(U): + return U.to_sparse() + + +@expect.dispatch +def infidelity(vstate: FullSumState, op: InfidelityOperatorUPsi): + if op.hilbert != vstate.hilbert: + raise TypeError("Hilbert spaces should match") + if not isinstance(op.target, FullSumState): + raise TypeError("Can only compute infidelity of exact states.") + + U_sp = sparsify(op._U) + Ustate_t = U_sp @ op.target.to_array(normalize=False) + + return infidelity_sampling_FullSumState( + vstate._apply_fun, + vstate.parameters, + vstate.model_state, + vstate._all_states, + Ustate_t, + return_grad=False, + ) + + +@expect_and_grad.dispatch +def infidelity( # noqa: F811 + vstate: FullSumState, + op: InfidelityOperatorUPsi, + *, + mutable, +): + if op.hilbert != vstate.hilbert: + raise TypeError("Hilbert spaces should match") + if not isinstance(op.target, FullSumState): + raise TypeError("Can only compute infidelity of exact states.") + + U_sp = sparsify(op._U) + Ustate_t = U_sp @ op.target.to_array(normalize=False) + + return infidelity_sampling_FullSumState( + vstate._apply_fun, + vstate.parameters, + vstate.model_state, + vstate._all_states, + Ustate_t, + return_grad=True, + ) + + +@partial(jax.jit, static_argnames=("afun", "return_grad")) +def infidelity_sampling_FullSumState( + afun, + params, + model_state, + sigma, + Ustate_t, + return_grad, +): + def expect_fun(params): + state = jnp.exp(afun({"params": params, **model_state}, sigma)) + state = state / jnp.sqrt(jnp.sum(jnp.abs(state) ** 2)) + return jnp.abs(state.T.conj().T @ Ustate_t) ** 2 / ( + Ustate_t.conj().T @ Ustate_t + ) + + if not return_grad: + F = expect_fun(params) + return Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) + + F, F_vjp_fun = nkjax.vjp(expect_fun, params, conjugate=True) + + F_grad = F_vjp_fun(jnp.ones_like(F))[0] + F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) + I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) + I_stats = Stats(mean=1 - F, error_of_mean=0.0, variance=0.0) + + return I_stats, I_grad diff --git a/netket_fidelity/infidelity/overlap_U/expect.py b/netket_fidelity/infidelity/overlap_U/expect.py index cb4fac1..c902e60 100644 --- a/netket_fidelity/infidelity/overlap_U/expect.py +++ b/netket_fidelity/infidelity/overlap_U/expect.py @@ -1,161 +1,161 @@ -from functools import partial - -import jax -import jax.numpy as jnp -from jax.scipy.special import logsumexp - -from netket import jax as nkjax -from netket.operator import DiscreteJaxOperator -from netket.vqs import MCState, expect, expect_and_grad, get_local_kernel_arguments -from netket.utils import mpi - - -from .operator import InfidelityOperatorUPsi - - -@expect.dispatch -def infidelity(vstate: MCState, op: InfidelityOperatorUPsi, chunk_size: None): - if op.hilbert != vstate.hilbert: - raise TypeError("Hilbert spaces should match") - - sigma, args = get_local_kernel_arguments(vstate, op._U) - sigma_t, args_t = get_local_kernel_arguments(op.target, op._U_dagger) - - return infidelity_sampling_MCState( - vstate._apply_fun, - op.target._apply_fun, - vstate.parameters, - op.target.parameters, - vstate.model_state, - op.target.model_state, - sigma, - args, - sigma_t, - args_t, - op.cv_coeff, - return_grad=False, - ) - - -@expect_and_grad.dispatch -def infidelity( # noqa: F811 - vstate: MCState, - op: InfidelityOperatorUPsi, - chunk_size: None, - *, - mutable, -): - if op.hilbert != vstate.hilbert: - raise TypeError("Hilbert spaces should match") - - sigma, args = get_local_kernel_arguments(vstate, op._U) - sigma_t, args_t = get_local_kernel_arguments(op.target, op._U_dagger) - - return infidelity_sampling_MCState( - vstate._apply_fun, - op.target._apply_fun, - vstate.parameters, - op.target.parameters, - vstate.model_state, - op.target.model_state, - sigma, - args, - sigma_t, - args_t, - op.cv_coeff, - return_grad=True, - ) - - -@partial(jax.jit, static_argnames=("afun", "afun_t", "return_grad")) -def infidelity_sampling_MCState( - afun, - afun_t, - params, - params_t, - model_state, - model_state_t, - sigma, - args, - sigma_t, - args_t, - cv_coeff, - return_grad, -): - N = sigma.shape[-1] - n_chains_t = sigma_t.shape[-2] - - σ = sigma.reshape(-1, N) - σ_t = sigma_t.reshape(-1, N) - - if isinstance(args, DiscreteJaxOperator): - xp, mels = args.get_conn_padded(σ) - xp_t, mels_t = args_t.get_conn_padded(σ_t) - else: - xp = args[0].reshape(σ.shape[0], -1, N) - mels = args[1].reshape(σ.shape[0], -1) - xp_t = args_t[0].reshape(σ_t.shape[0], -1, N) - mels_t = args_t[1].reshape(σ_t.shape[0], -1) - - def expect_kernel(params): - def kernel_fun(params_all, samples_all): - params, params_t = params_all - σ, σ_t = samples_all - - W = {"params": params, **model_state} - W_t = {"params": params_t, **model_state_t} - - logpsi_t_xp = afun_t(W_t, xp) - logpsi_xp_t = afun(W, xp_t) - - log_val = ( - logsumexp(logpsi_t_xp, axis=-1, b=mels) - + logsumexp(logpsi_xp_t, axis=-1, b=mels_t) - - afun(W, σ) - - afun_t(W_t, σ_t) - ) - res = jnp.exp(log_val).real - if cv_coeff is not None: - res = res + cv_coeff * (jnp.exp(2 * log_val.real) - 1) - return res - - log_pdf = lambda params, σ: 2 * afun({"params": params, **model_state}, σ).real - log_pdf_t = ( - lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real - ) - - def log_pdf_joint(params_all, samples_all): - params, params_t = params_all - σ, σ_t = samples_all - log_pdf_vals = log_pdf(params, σ) - log_pdf_t_vals = log_pdf_t(params_t, σ_t) - return log_pdf_vals + log_pdf_t_vals - - return nkjax.expect( - log_pdf_joint, - kernel_fun, - ( - params, - params_t, - ), - ( - σ, - σ_t, - ), - n_chains=n_chains_t, - ) - - if not return_grad: - F, F_stats = expect_kernel(params) - return F_stats.replace(mean=1 - F) - - F, F_vjp_fun, F_stats = nkjax.vjp( - expect_kernel, params, has_aux=True, conjugate=True - ) - - F_grad = F_vjp_fun(jnp.ones_like(F))[0] - F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) - I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) - I_stats = F_stats.replace(mean=1 - F) - - return I_stats, I_grad +from functools import partial + +import jax +import jax.numpy as jnp +from jax.scipy.special import logsumexp + +from netket import jax as nkjax +from netket.operator import DiscreteJaxOperator +from netket.vqs import MCState, expect, expect_and_grad, get_local_kernel_arguments +from netket.utils import mpi + + +from .operator import InfidelityOperatorUPsi + + +@expect.dispatch +def infidelity(vstate: MCState, op: InfidelityOperatorUPsi, chunk_size: None): + if op.hilbert != vstate.hilbert: + raise TypeError("Hilbert spaces should match") + + sigma, args = get_local_kernel_arguments(vstate, op._U) + sigma_t, args_t = get_local_kernel_arguments(op.target, op._U_dagger) + + return infidelity_sampling_MCState( + vstate._apply_fun, + op.target._apply_fun, + vstate.parameters, + op.target.parameters, + vstate.model_state, + op.target.model_state, + sigma, + args, + sigma_t, + args_t, + op.cv_coeff, + return_grad=False, + ) + + +@expect_and_grad.dispatch +def infidelity( # noqa: F811 + vstate: MCState, + op: InfidelityOperatorUPsi, + chunk_size: None, + *, + mutable, +): + if op.hilbert != vstate.hilbert: + raise TypeError("Hilbert spaces should match") + + sigma, args = get_local_kernel_arguments(vstate, op._U) + sigma_t, args_t = get_local_kernel_arguments(op.target, op._U_dagger) + + return infidelity_sampling_MCState( + vstate._apply_fun, + op.target._apply_fun, + vstate.parameters, + op.target.parameters, + vstate.model_state, + op.target.model_state, + sigma, + args, + sigma_t, + args_t, + op.cv_coeff, + return_grad=True, + ) + + +@partial(jax.jit, static_argnames=("afun", "afun_t", "return_grad")) +def infidelity_sampling_MCState( + afun, + afun_t, + params, + params_t, + model_state, + model_state_t, + sigma, + args, + sigma_t, + args_t, + cv_coeff, + return_grad, +): + N = sigma.shape[-1] + n_chains_t = sigma_t.shape[-2] + + σ = sigma.reshape(-1, N) + σ_t = sigma_t.reshape(-1, N) + + if isinstance(args, DiscreteJaxOperator): + xp, mels = args.get_conn_padded(σ) + xp_t, mels_t = args_t.get_conn_padded(σ_t) + else: + xp = args[0].reshape(σ.shape[0], -1, N) + mels = args[1].reshape(σ.shape[0], -1) + xp_t = args_t[0].reshape(σ_t.shape[0], -1, N) + mels_t = args_t[1].reshape(σ_t.shape[0], -1) + + def expect_kernel(params): + def kernel_fun(params_all, samples_all): + params, params_t = params_all + σ, σ_t = samples_all + + W = {"params": params, **model_state} + W_t = {"params": params_t, **model_state_t} + + logpsi_t_xp = afun_t(W_t, xp) + logpsi_xp_t = afun(W, xp_t) + + log_val = ( + logsumexp(logpsi_t_xp, axis=-1, b=mels) + + logsumexp(logpsi_xp_t, axis=-1, b=mels_t) + - afun(W, σ) + - afun_t(W_t, σ_t) + ) + res = jnp.exp(log_val).real + if cv_coeff is not None: + res = res + cv_coeff * (jnp.exp(2 * log_val.real) - 1) + return res + + log_pdf = lambda params, σ: 2 * afun({"params": params, **model_state}, σ).real + log_pdf_t = ( + lambda params, σ: 2 * afun_t({"params": params, **model_state_t}, σ).real + ) + + def log_pdf_joint(params_all, samples_all): + params, params_t = params_all + σ, σ_t = samples_all + log_pdf_vals = log_pdf(params, σ) + log_pdf_t_vals = log_pdf_t(params_t, σ_t) + return log_pdf_vals + log_pdf_t_vals + + return nkjax.expect( + log_pdf_joint, + kernel_fun, + ( + params, + params_t, + ), + ( + σ, + σ_t, + ), + n_chains=n_chains_t, + ) + + if not return_grad: + F, F_stats = expect_kernel(params) + return F_stats.replace(mean=1 - F) + + F, F_vjp_fun, F_stats = nkjax.vjp( + expect_kernel, params, has_aux=True, conjugate=True + ) + + F_grad = F_vjp_fun(jnp.ones_like(F))[0] + F_grad = jax.tree_util.tree_map(lambda x: mpi.mpi_mean_jax(x)[0], F_grad) + I_grad = jax.tree_util.tree_map(lambda x: -x, F_grad) + I_stats = F_stats.replace(mean=1 - F) + + return I_stats, I_grad diff --git a/netket_fidelity/infidelity/overlap_U/operator.py b/netket_fidelity/infidelity/overlap_U/operator.py index dc6f70e..2d061cc 100644 --- a/netket_fidelity/infidelity/overlap_U/operator.py +++ b/netket_fidelity/infidelity/overlap_U/operator.py @@ -1,67 +1,67 @@ -from typing import Optional -import jax.numpy as jnp - -from netket.experimental.observable import AbstractObservable - -from netket.operator import AbstractOperator -from netket.utils.types import DType -from netket.utils.numbers import is_scalar -from netket.vqs import VariationalState, FullSumState - - -class InfidelityOperatorUPsi(AbstractObservable): - def __init__( - self, - U: AbstractOperator, - state: VariationalState, - *, - cv_coeff: Optional[float] = None, - U_dagger: AbstractOperator, - is_unitary: bool = False, - dtype: Optional[DType] = None, - ): - super().__init__(state.hilbert) - - if not isinstance(state, VariationalState): - raise TypeError("The first argument should be a variational state.") - - if not is_unitary and not isinstance(state, FullSumState): - raise ValueError( - "Only works with unitary gates. If the gate is non unitary " - "then you must sample from it. Use a different operator." - ) - - if cv_coeff is not None: - cv_coeff = jnp.array(cv_coeff) - - if (not is_scalar(cv_coeff)) or jnp.iscomplex(cv_coeff): - raise TypeError("`cv_coeff` should be a real scalar number or None.") - - if isinstance(state, FullSumState): - cv_coeff = None - - self._target = state - self._cv_coeff = cv_coeff - self._dtype = dtype - - self._U = U - self._U_dagger = U_dagger - - @property - def target(self): - return self._target - - @property - def cv_coeff(self): - return self._cv_coeff - - @property - def dtype(self): - return self._dtype - - @property - def is_hermitian(self): - return True - - def __repr__(self): - return f"InfidelityOperatorUPsi(target=U@{self.target}, U={self._U}, cv_coeff={self.cv_coeff})" +from typing import Optional +import jax.numpy as jnp + +from netket.experimental.observable import AbstractObservable + +from netket.operator import AbstractOperator +from netket.utils.types import DType +from netket.utils.numbers import is_scalar +from netket.vqs import VariationalState, FullSumState + + +class InfidelityOperatorUPsi(AbstractObservable): + def __init__( + self, + U: AbstractOperator, + state: VariationalState, + *, + cv_coeff: Optional[float] = None, + U_dagger: AbstractOperator, + is_unitary: bool = False, + dtype: Optional[DType] = None, + ): + super().__init__(state.hilbert) + + if not isinstance(state, VariationalState): + raise TypeError("The first argument should be a variational state.") + + if not is_unitary and not isinstance(state, FullSumState): + raise ValueError( + "Only works with unitary gates. If the gate is non unitary " + "then you must sample from it. Use a different operator." + ) + + if cv_coeff is not None: + cv_coeff = jnp.array(cv_coeff) + + if (not is_scalar(cv_coeff)) or jnp.iscomplex(cv_coeff): + raise TypeError("`cv_coeff` should be a real scalar number or None.") + + if isinstance(state, FullSumState): + cv_coeff = None + + self._target = state + self._cv_coeff = cv_coeff + self._dtype = dtype + + self._U = U + self._U_dagger = U_dagger + + @property + def target(self): + return self._target + + @property + def cv_coeff(self): + return self._cv_coeff + + @property + def dtype(self): + return self._dtype + + @property + def is_hermitian(self): + return True + + def __repr__(self): + return f"InfidelityOperatorUPsi(target=U@{self.target}, U={self._U}, cv_coeff={self.cv_coeff})"