diff --git a/test/test_operator.py b/test/test_operator.py index 316c316..9ce3dbc 100644 --- a/test/test_operator.py +++ b/test/test_operator.py @@ -3,7 +3,6 @@ import numpy as np from netket_fidelity.operator import singlequbit_gates as sg from jax import numpy as jnp - import netket_fidelity as nkf @@ -29,40 +28,83 @@ def test_operator_dense_and_conversion(operator): assert operator.hilbert == operator.to_local_operator().hilbert -def test_get_conns(): +def test_get_conns_and_mels(): hi_spin = nk.hilbert.Spin(s=0.5, N=3) hi_qubit = nk.hilbert.Qubit(N=3) local_state_spin = hi_spin.local_states local_state_qubit = hi_qubit.local_states - sigma_4_qubit = hi_qubit.numbers_to_states(2) + sigma_2_qubit = hi_qubit.numbers_to_states(2) sigma_7_qubit = hi_qubit.numbers_to_states(7) - sigma_4_spin = hi_spin.numbers_to_states(2) + sigma_2_spin = hi_spin.numbers_to_states(2) sigma_7_spin = hi_spin.numbers_to_states(7) - sigma_qubit = jnp.array([sigma_4_qubit, sigma_7_qubit]) - sigma_spin = jnp.array([sigma_4_spin, sigma_7_spin]) + sigma_qubit = jnp.array([sigma_2_qubit, sigma_7_qubit]) + sigma_spin = jnp.array([sigma_2_spin, sigma_7_spin]) - conns_rx_qubit, _ = sg.get_conns_and_mels_Rx(sigma_qubit, 0, 0, local_state_qubit) - conns_ry_qubit, _ = sg.get_conns_and_mels_Ry(sigma_qubit, 0, 0, local_state_qubit) - conns_h_qubit, _ = sg.get_conns_and_mels_Hadamard(sigma_qubit, 0, local_state_qubit) + conns_rx_qubit, mels_rx_qubit = sg.get_conns_and_mels_Rx( + sigma_qubit, 0, np.pi / 2, local_state_qubit + ) + conns_ry_qubit, mels_ry_qubit = sg.get_conns_and_mels_Ry( + sigma_qubit, 0, np.pi / 2, local_state_qubit + ) + conns_h_qubit, mels_h_qubit = sg.get_conns_and_mels_Hadamard( + sigma_qubit, 0, local_state_qubit + ) - conns_rx_spin, _ = sg.get_conns_and_mels_Rx(sigma_spin, 0, 0, local_state_spin) - conns_ry_spin, _ = sg.get_conns_and_mels_Ry(sigma_spin, 0, 0, local_state_spin) - conns_h_spin, _ = sg.get_conns_and_mels_Hadamard(sigma_spin, 0, local_state_spin) + conns_rx_spin, mels_rx_spin = sg.get_conns_and_mels_Rx( + sigma_spin, 0, np.pi / 2, local_state_spin + ) + conns_ry_spin, mels_ry_spin = sg.get_conns_and_mels_Ry( + sigma_spin, 0, np.pi / 2, local_state_spin + ) + conns_h_spin, mels_h_spin = sg.get_conns_and_mels_Hadamard( + sigma_spin, 0, local_state_spin + ) - values_check_qubit = jnp.array( + conns_check_qubit = jnp.array( [[[0.0, 1.0, 0.0], [1.0, 1.0, 0.0]], [[1.0, 1.0, 1.0], [0.0, 1.0, 1.0]]] ) - values_check_spin = jnp.array( + conns_check_spin = jnp.array( [[[-1.0, 1.0, -1.0], [1.0, 1.0, -1.0]], [[1.0, 1.0, 1.0], [-1.0, 1.0, 1.0]]] ) - assert (conns_rx_qubit == values_check_qubit).all() - assert (conns_ry_qubit == values_check_qubit).all() - assert (conns_h_qubit == values_check_qubit).all() - assert (conns_rx_spin == values_check_spin).all() - assert (conns_ry_spin == values_check_spin).all() - assert (conns_h_spin == values_check_spin).all() + mels_check_qubit_rx = jnp.array( + [[0.70710678 + 0.0j, 0.0 - 0.70710678j], [0.70710678 + 0.0j, 0.0 - 0.70710678j]] + ) + mels_check_qubit_ry = jnp.array( + [ + [0.70710678 + 0.0j, 0.70710678 + 0.0j], + [0.70710678 + 0.0j, -0.70710678 + 0.0j], + ] + ) + mels_check_qubit_h = jnp.array( + [[0.70710678, 0.70710678], [-0.70710678, 0.70710678]] + ) + + mels_check_spin_rx = jnp.array( + [[0.70710678 + 0.0j, 0.0 - 0.70710678j], [0.70710678 + 0.0j, 0.0 - 0.70710678j]] + ) + mels_check_spin_ry = jnp.array( + [ + [0.70710678 + 0.0j, 0.70710678 + 0.0j], + [0.70710678 + 0.0j, -0.70710678 + 0.0j], + ] + ) + mels_check_spin_h = jnp.array([[0.70710678, 0.70710678], [-0.70710678, 0.70710678]]) + + np.testing.assert_allclose(conns_rx_qubit, conns_check_qubit) + np.testing.assert_allclose(conns_ry_qubit, conns_check_qubit) + np.testing.assert_allclose(conns_h_qubit, conns_check_qubit) + np.testing.assert_allclose(conns_rx_spin, conns_check_spin) + np.testing.assert_allclose(conns_ry_spin, conns_check_spin) + np.testing.assert_allclose(conns_h_spin, conns_check_spin) + + np.testing.assert_allclose(mels_rx_qubit, mels_check_qubit_rx) + np.testing.assert_allclose(mels_ry_qubit, mels_check_qubit_ry) + np.testing.assert_allclose(mels_h_qubit, mels_check_qubit_h) + np.testing.assert_allclose(mels_rx_spin, mels_check_spin_rx) + np.testing.assert_allclose(mels_ry_spin, mels_check_spin_ry) + np.testing.assert_allclose(mels_h_spin, mels_check_spin_h)