diff --git a/pdm.lock b/pdm.lock index 83ff8ce88..4088a1dd8 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "doc"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:84eed562c36a88f6443e9407272b31867720ae9b67aa26a57b769415b2dd03de" +content_hash = "sha256:9898bdd0fbd80f75a5944c00a27603a402b71a613893d1e67c01ecc1c702ebcf" [[package]] name = "amazon-braket-default-simulator" @@ -4129,6 +4129,20 @@ files = [ {file = "tornado-6.4.tar.gz", hash = "sha256:72291fa6e6bc84e626589f1c29d90a5a6d593ef5ae68052ee2ef000dfd273dee"}, ] +[[package]] +name = "tqdm" +version = "4.66.4" +requires_python = ">=3.7" +summary = "Fast, Extensible Progress Meter" +groups = ["dev"] +dependencies = [ + "colorama; platform_system == \"Windows\"", +] +files = [ + {file = "tqdm-4.66.4-py3-none-any.whl", hash = "sha256:b75ca56b413b030bc3f00af51fd2c1a1a5eac6a0c1cca83cbb37a5c52abce644"}, + {file = "tqdm-4.66.4.tar.gz", hash = "sha256:e4d936c9de8727928f3be6079590e97d9abfe8d39a590be678eb5919ffc186bb"}, +] + [[package]] name = "traitlets" version = "5.14.1" diff --git a/pyproject.toml b/pyproject.toml index f4632342d..4d68eb005 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,7 @@ dev = [ "scikit-optimize>=0.9.0", "matplotlib>=3.8.1", "icecream>=2.1.3", + "tqdm>=4.66.4", ] [tool.pdm.scripts] diff --git a/src/bloqade/emulate/ir/emulator.py b/src/bloqade/emulate/ir/emulator.py index 48220e0ba..8bde0fdcd 100644 --- a/src/bloqade/emulate/ir/emulator.py +++ b/src/bloqade/emulate/ir/emulator.py @@ -172,15 +172,35 @@ def __post_init__(self): def __len__(self): return len(self.sites) + def same(self, other: Any) -> bool: + if not isinstance(other, Register): + return False + + return ( + self.atom_type == other.atom_type + and self.blockade_radius == other.blockade_radius + and set(self.sites) == set(other.sites) + ) + + # for hashing/comparing I only want to compare based on the effective + # fock space that gets generated which is determined by the geometry + # only if the blockade radius is non-zero. I overload the __eq__ and + # __hash__ methods to achieve this. + def __eq__(self, other: Any): - if isinstance(other, Register): - return ( - self.atom_type == other.atom_type - and self.blockade_radius == other.blockade_radius - and set(self.sites) == set(other.sites) + if not isinstance(other, Register): + return False + + if self.blockade_radius == Decimal("0") and other.blockade_radius == Decimal( + "0" + ): + # if blockade radius is zero, then the positions are irrelevant + # because the fock states generated by the geometry are the same + return self.atom_type == other.atom_type and len(self.sites) == len( + other.sites ) - return False + return self.same(other) def __hash__(self) -> int: if self.blockade_radius == Decimal("0"): diff --git a/src/bloqade/emulate/ir/space.py b/src/bloqade/emulate/ir/space.py index e12f188f3..3b97478ca 100644 --- a/src/bloqade/emulate/ir/space.py +++ b/src/bloqade/emulate/ir/space.py @@ -1,12 +1,13 @@ from dataclasses import dataclass from numpy.typing import NDArray -from beartype.typing import TYPE_CHECKING +from beartype.typing import TYPE_CHECKING, Any import numpy as np from enum import Enum if TYPE_CHECKING: from .emulator import Register from .atom_type import AtomType + from .state_vector import StateVector MAX_PRINT_SIZE = 30 @@ -28,6 +29,12 @@ def __post_init__(self): assert isinstance(self.program_register, Register) + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Space): + return False + + return self.program_register == other.program_register + @classmethod def create(cls, register: "Register"): sites = register.sites @@ -193,10 +200,12 @@ def index_to_fock_state(self, index: int) -> str: self.configurations[index], self.n_atoms ) - def zero_state(self, dtype=np.float64) -> NDArray: + def zero_state(self, dtype=np.float64) -> "StateVector": + from .state_vector import StateVector + state = np.zeros(self.size, dtype=dtype) state[0] = 1.0 - return state + return StateVector(state, self) def sample_state_vector( self, state_vector: NDArray, n_samples: int, project_hyperfine: bool = True diff --git a/src/bloqade/emulate/ir/state_vector.py b/src/bloqade/emulate/ir/state_vector.py index 82176670b..33ffbc3f4 100644 --- a/src/bloqade/emulate/ir/state_vector.py +++ b/src/bloqade/emulate/ir/state_vector.py @@ -1,6 +1,6 @@ import plum from bloqade.emulate.ir.emulator import EmulatorProgram -from bloqade.emulate.ir.space import Space +from bloqade.emulate.ir.space import Space, MAX_PRINT_SIZE from bloqade.emulate.sparse_operator import ( IndexMapping, SparseMatrixCSC, @@ -8,12 +8,13 @@ ) from dataclasses import dataclass, field from numpy.typing import NDArray -from beartype.typing import List, Callable, Union, Optional, Tuple +from beartype.typing import List, Callable, Union, Optional, Tuple, Iterator, Sequence from beartype.vale import IsAttr, IsEqual from typing import Annotated from beartype import beartype import numpy as np from scipy.integrate import ode +from scipy.sparse import csr_matrix, diags from numba import njit SparseOperator = Union[IndexMapping, SparseMatrixCSR, SparseMatrixCSC] @@ -169,7 +170,7 @@ def local_trace(self, matrix: np.ndarray, site_index: int) -> complex: # noqa: op=matrix, ) - return complex(value.real, value.imag) + return complex(value.real, value.imag) / self.norm() @plum.dispatch def local_trace( # noqa: F811 @@ -198,6 +199,62 @@ def local_trace( # noqa: F811 """ ... + def sample(self, shots: int, project_hyperfine: bool = True) -> NDArray: + """Sample the state vector and return bitstrings.""" + return self.space.sample_state_vector( + self.data, shots, project_hyperfine=project_hyperfine + ) + + def normalize(self) -> None: + """Normalize the state vector.""" + data = self.data + data /= np.linalg.norm(data) + + def norm(self) -> float: + """Return the norm of the state vector.""" + return np.linalg.norm(self.data) + + def __str__(self) -> str: + output = "" + + n_digits = len(str(self.space.size - 1)) + fmt = "{{index: >{}d}}. {{fock_state:s}} {{coeff:}}\n".format(n_digits) + if self.space.size < MAX_PRINT_SIZE: + for index, state_int in enumerate(self.space.configurations): + fock_state = self.space.atom_type.integer_to_string( + state_int, self.space.n_atoms + ) + output = output + fmt.format( + index=index, fock_state=fock_state, coeff=self.data[index] + ) + + else: + lower_index = MAX_PRINT_SIZE // 2 + (MAX_PRINT_SIZE % 2) + upper_index = self.space.size - MAX_PRINT_SIZE // 2 + + for index, state_int in enumerate(self.space.configurations[:lower_index]): + fock_state = self.space.atom_type.integer_to_string( + state_int, self.space.n_atoms + ) + output = output + fmt.format( + index=index, fock_state=fock_state, coeff=self.data[index] + ) + + output += (n_digits * " ") + "...\n" + + for index, state_int in enumerate( + self.space.configurations[upper_index:], + start=self.space.size - MAX_PRINT_SIZE // 2, + ): + fock_state = self.space.atom_type.integer_to_string( + state_int, self.space.n_atoms + ) + output = output + fmt.format( + index=index, fock_state=fock_state, coeff=self.data[index] + ) + + return output + @dataclass(frozen=True) class DetuningOperator: @@ -228,6 +285,16 @@ def dot(self, register: NDArray, output: NDArray, time: float): return output + def tocsr(self, time: float) -> csr_matrix: + amplitude = self.amplitude(time) / 2 + if self.phase is None: + return self.op.tocsr() * amplitude + + amplitude: np.complexfloating = amplitude * np.exp(1j * self.phase(time)) + mat = self.op.tocsr() * amplitude + + return mat + mat.T.conj() + @dataclass(frozen=True) class RydbergHamiltonian: @@ -291,26 +358,26 @@ def _check_register(self, register: np.ndarray): def _apply( self, - register_data: np.ndarray, + register: np.ndarray, time: Optional[float] = None, output: Optional[NDArray] = None, ) -> np.ndarray: - self._check_register(register_data) + self._check_register(register) if time is None: time = self.emulator_ir.duration if output is None: - output = np.zeros_like(register_data, dtype=np.complex128) + output = np.zeros_like(register, dtype=np.complex128) diagonal = sum( (detuning.get_diagonal(time) for detuning in self.detuning_ops), + start=self.rydberg, ) - np.multiply(diagonal, register_data, out=output) - + np.multiply(diagonal, register, out=output) for rabi_op in self.rabi_ops: - rabi_op.dot(register_data, output, time) + rabi_op.dot(register, output, time) return output @@ -379,30 +446,26 @@ def variance( _, var = self.average_and_variance(register, time) return var - @plum.dispatch - def expectation_value( # noqa: F811 - self, register: np.ndarray, operator: np.ndarray, site_indices: int - ) -> complex: - """Calculate expectation values of one and two body operators. + def tocsr(self, time: float) -> csr_matrix: + """Return the Hamiltonian as a csr matrix at time `time`. Args: - register (np.ndarray): Register to evaluate expectation value with - operator (np.ndarray): Operator to take expectation value of. - site_indices (int, Tuple[int, int]): site/sites to evaluate `operator` at. - It can either a single integer or a tuple of two integers for one and - two body operator respectively. - - Raises: - ValueError: Error is raised when the dimension of `operator` is not - consistent with `site` argument. The size of the operator must fit the - size of the local hilbert space of `site` depending on the number of sites - and the number of levels inside each atom, e.g. for two site expectation v - alue with a three level atom the operator must be a 9 by 9 array. + time (float): time to evaluate the Hamiltonian at. Returns: - complex: The expectation value. + csr_matrix: The Hamiltonian as a csr matrix. + """ - self._check_register(register) + diagonal = sum( + (detuning.get_diagonal(time) for detuning in self.detuning_ops), + start=self.rydberg, + ) + + hamiltonian = diags(diagonal).tocsr() + for rabi_op in self.rabi_ops: + hamiltonian = hamiltonian + rabi_op.tocsr(time) + + return hamiltonian @dataclass(frozen=True) @@ -465,94 +528,115 @@ def _error_check(solver_name: str, status_code: int): elif solver_name in ["dop853", "dopri5"]: AnalogGate._error_check_dop(status_code) - def _apply( + def _check_args( self, - state: StateArray, - solver_name: str = "dop853", - atol: float = 1e-7, - rtol: float = 1e-14, - nsteps: int = 2_147_483_647, - times: Union[List[float], RealArray] = [], + state_vec: StateVector, + solver_name: str, + atol: float, + rtol: float, + nsteps: int, + times: Sequence[float], ): - if state is None: - state = self.hamiltonian.space.zero_state() + duration = self.hamiltonian.emulator_ir.duration + times = [duration] if len(times) == 0 else times + if state_vec is None: + state_vec = self.hamiltonian.space.zero_state(np.complex128) + + if state_vec.space != self.hamiltonian.space: + raise ValueError("State vector not in the same space as the Hamiltonian.") if solver_name not in AnalogGate.SUPPORTED_SOLVERS: raise ValueError(f"'{solver_name}' not supported.") - duration = self.hamiltonian.emulator_ir.duration + if any(time > duration or time < 0.0 for time in times): + raise ValueError( + f"Times must be between 0 and duration {duration}. found {times}" + ) - state = np.asarray(state).astype(np.complex128, copy=False) + return state_vec, solver_name, atol, rtol, nsteps, times - solver = ode(self.hamiltonian._ode_real_kernel) - solver.set_f_params(np.zeros_like(state, dtype=np.complex128)) - solver.set_initial_value(state.view(np.float64)) - solver.set_integrator(solver_name, atol=atol, rtol=rtol, nsteps=nsteps) + def _apply( + self, + state_vec: StateVector, + solver_name: str = "dop853", + atol: float = 1e-7, + rtol: float = 1e-14, + nsteps: int = 2_147_483_647, + times: Sequence[float] = (), + ) -> Iterator[StateVector]: - if any(time >= duration or time < 0.0 for time in times): - raise ValueError("Times must be between 0 and duration.") + state_vec, solver_name, atol, rtol, nsteps, times = self._check_args( + state_vec, solver_name, atol, rtol, nsteps, times + ) + state_data = np.asarray(state_vec.data).astype(np.complex128, copy=False) - times = [*times, duration] + solver = ode(self.hamiltonian._ode_real_kernel) + solver.set_f_params(np.zeros_like(state_data, dtype=np.complex128)) + solver.set_initial_value(state_data.view(np.float64)) + solver.set_integrator(solver_name, atol=atol, rtol=rtol, nsteps=nsteps) for time in times: - if time == 0.0: - yield state + if solver.t == time: + yield StateVector(solver.y.view(np.complex128), self.hamiltonian.space) continue + solver.integrate(time) AnalogGate._error_check(solver_name, solver.get_return_code()) - yield solver.y.view(np.complex128) - def _apply_interation_picture( + yield StateVector(solver.y.view(np.complex128), self.hamiltonian.space) + + def _apply_interaction_picture( self, - state: StateArray, + state_vec: StateVector, solver_name: str = "dop853", atol: float = 1e-7, rtol: float = 1e-14, nsteps: int = 2_147_483_647, - times: Union[List[float], RealArray] = [], - ): - if state is None: - state = self.hamiltonian.space.zero_state() + times: Sequence[float] = (), + ) -> Iterator[StateVector]: - if solver_name not in AnalogGate.SUPPORTED_SOLVERS: - raise ValueError(f"'{solver_name}' not supported.") - - duration = self.hamiltonian.emulator_ir.duration - - state = np.asarray(state).astype(np.complex128, copy=False) + state_vec, solver_name, atol, rtol, nsteps, times = self._check_args( + state_vec, solver_name, atol, rtol, nsteps, times + ) + state_data = np.asarray(state_vec.data).astype(np.complex128, copy=False) solver = ode(self.hamiltonian._ode_real_kernel_int) - solver.set_f_params(np.zeros_like(state, dtype=np.complex128)) - solver.set_initial_value(state.view(np.float64)) + solver.set_f_params(np.zeros_like(state_data, dtype=np.complex128)) + solver.set_initial_value(state_data.view(np.float64)) solver.set_integrator(solver_name, atol=atol, rtol=rtol, nsteps=nsteps) - if any(time >= duration or time < 0.0 for time in times): - raise ValueError("Times must be between 0 and duration.") - - times = [*times, duration] + state_vec_t = state_vec for time in times: - if time == 0.0: - yield state + if time == solver.t: + # if the time is the same as the current time, + # do not call the integrator, just yield state + yield state_vec_t continue + solver.integrate(time) AnalogGate._error_check(solver_name, solver.get_return_code()) + # go back to the schrodinger picture u = np.exp(-1j * time * self.hamiltonian.rydberg) - yield u * solver.y.view(np.complex128) + state_vec_t = StateVector( + u * solver.y.view(np.complex128), self.hamiltonian.space + ) + # yield the state vector in the schrodinger picture + yield state_vec_t @beartype def apply( self, - state: StateArray, + state: StateVector, solver_name: str = "dop853", atol: float = 1e-7, rtol: float = 1e-14, nsteps: int = 2_147_483_647, - times: Union[List[float], RealArray] = [], + times: Union[Sequence[float], RealArray] = (), interaction_picture: bool = False, ): if interaction_picture: - return self._apply_interation_picture( + return self._apply_interaction_picture( state, solver_name=solver_name, atol=atol, @@ -580,7 +664,7 @@ def run( nsteps: int = 2_147_483_647, interaction_picture: bool = False, project_hyperfine: bool = True, - ): + ) -> NDArray[np.uint8]: """Run the emulation with all atoms in the ground state, sampling the final state vector.""" @@ -594,8 +678,6 @@ def run( state = self.hamiltonian.space.zero_state() (result,) = self.apply(state, **options) - result /= np.linalg.norm(result) + result.normalize() - return self.hamiltonian.space.sample_state_vector( - result, shots, project_hyperfine=project_hyperfine - ) + return result.sample(shots, project_hyperfine=project_hyperfine) diff --git a/src/bloqade/ir/routine/bloqade.py b/src/bloqade/ir/routine/bloqade.py index c7deaa142..4140f0faf 100644 --- a/src/bloqade/ir/routine/bloqade.py +++ b/src/bloqade/ir/routine/bloqade.py @@ -14,6 +14,7 @@ List, NamedTuple, Iterator, + Sequence, ) from pydantic.v1.dataclasses import dataclass import dataclasses @@ -52,7 +53,7 @@ def metadata(self) -> NamedTuple: @dataclasses.dataclass(frozen=True) -class HamiltonianData: +class BloqadeEmulation: """Data class to hold the Hamiltonian and metadata for a given set of parameters""" task_data: TaskData @@ -63,6 +64,7 @@ class HamiltonianData: @property def hamiltonian(self) -> RydbergHamiltonian: + """Return the Hamiltonian object for the given task data.""" if self._hamiltonian is None: _hamiltonian = RydbergHamiltonianCodeGen(self.compile_cache).emit( self.task_data.emulator_ir @@ -72,8 +74,69 @@ def hamiltonian(self) -> RydbergHamiltonian: @property def metadata(self) -> NamedTuple: + """The metadata for the given task data.""" return self.task_data.metadata + def zero_state(self, dtype: np.dtype = np.float64) -> StateVector: + """Return the zero state for the given Hamiltonian.""" + return self.hamiltonian.space.zero_state(dtype) + + def fock_state( + self, fock_state_str: str, dtype: np.dtype = np.float64 + ) -> StateVector: + """Return the fock state for the given Hamiltonian.""" + index = self.hamiltonian.space.fock_state_to_index(fock_state_str) + data = np.zeros(self.hamiltonian.space.size, dtype=dtype) + data[index] = 1 + return StateVector(data, self.hamiltonian.space) + + def evolve( + self, + state: Optional[StateVector] = None, + solver_name: str = "dop853", + atol: float = 1e-7, + rtol: float = 1e-14, + nsteps: int = 2147483647, + times: Sequence[float] = (), + interaction_picture: bool = False, + ) -> Iterator[StateVector]: + """Evolve an initial state vector using the Hamiltonian + + Args: + state (Optional[StateVector], optional): The initial state vector to + evolve. if not provided, the zero state will be used. Defaults to None. + solver_name (str, optional): Which SciPy Solver to use. Defaults to + "dop853". + atol (float, optional): Absolute tolerance for ODE solver. Defaults + to 1e-14. + rtol (float, optional): Relative tolerance for adaptive step in + ODE solver. Defaults to 1e-7. + nsteps (int, optional): Maximum number of steps allowed per integration + step. Defaults to 2147483647. + times (Sequence[float], optional): The times to evaluate the state vector + at. Defaults to (). If not provided the state will be evaluated at + the end of the bloqade program. + interaction_picture (bool, optional): Use the interaction picture when + solving schrodinger equation. Defaults to False. + + Returns: + Iterator[StateVector]: An iterator of the state vectors at each time step. + + """ + state = self.zero_state(np.complex128) if state is None else state + + U = AnalogGate(self.hamiltonian) + + return U.apply( + state, + times=times, + solver_name=solver_name, + atol=atol, + rtol=rtol, + nsteps=nsteps, + interaction_picture=interaction_picture, + ) + @dataclass(frozen=True, config=__pydantic_dataclass_config__) class BloqadePythonRoutine(RoutineBase): @@ -103,13 +166,10 @@ def run_task(self, emulator_ir, metadata_dict): metadata = MetaData( **{k: cast_to_float(v) for k, v in metadata_dict.items()} ) - zero_state = hamiltonian.space.zero_state(np.complex128) - (register_data,) = AnalogGate(hamiltonian).apply( + (wrapped_register,) = AnalogGate(hamiltonian).apply( zero_state, **self.solver_args ) - wrapped_register = StateVector(register_data, hamiltonian.space) - return self.callback( wrapped_register, metadata, hamiltonian, *self.callback_args ) @@ -442,7 +502,7 @@ def hamiltonian( use_hyperfine: bool = False, waveform_runtime: str = "interpret", cache_matrices: bool = False, - ) -> Iterator[HamiltonianData]: + ) -> List[BloqadeEmulation]: ir_iter = self._generate_ir( args, blockade_radius, waveform_runtime, use_hyperfine @@ -453,5 +513,7 @@ def hamiltonian( else: compile_cache = None - for task_data in ir_iter: - yield HamiltonianData(task_data, compile_cache=compile_cache) + return [ + BloqadeEmulation(task_data, compile_cache=compile_cache) + for task_data in ir_iter + ] diff --git a/tests/test_emulator_interface.py b/tests/test_emulator_interface.py new file mode 100644 index 000000000..bf86ad9da --- /dev/null +++ b/tests/test_emulator_interface.py @@ -0,0 +1,98 @@ +import pytest +from bloqade import start +from bloqade.atom_arrangement import Chain +import numpy as np +from itertools import product +from math import isclose + + +def test_zero_state(): + [emu] = ( + start.add_position([(0, 0), (0, 5), (0, 10)]) + .rydberg.rabi.amplitude.uniform.constant(15.0, 4.0) + .detuning.uniform.constant(1.0, 4.0) + .bloqade.python() + .hamiltonian() + ) + + data = np.zeros(8) + data[0] = 1 + + assert np.array_equal(emu.zero_state().data, data) + + +def test_fock_state(): + [emu] = ( + start.add_position([(0, 0), (0, 5), (0, 10)]) + .rydberg.rabi.amplitude.uniform.constant(15.0, 4.0) + .detuning.uniform.constant(1.0, 4.0) + .bloqade.python() + .hamiltonian() + ) + + for idx, bitstring in enumerate(product("gr", repeat=3)): + sv = emu.fock_state("".join(bitstring[::-1])) + data = np.zeros(8) + data[idx] = 1 + assert np.array_equal(sv.data, data) + + +@pytest.mark.parametrize( + ["N", "phi", "interaction"], + product([1, 2, 3, 4, 5], np.linspace(0, 2 * np.pi, 4), [True, False]), +) +def test_solution(N: int, phi: float, interaction: bool): + rabi_freq = 2 * np.pi + program = ( + Chain(N, lattice_spacing=6) + # .rydberg.detuning.uniform.constant(1.0, 4.0) + .rydberg.rabi.amplitude.uniform.constant(rabi_freq, 4.0) + ) + + if phi != 0: + program = program.phase.uniform.constant(phi, 4) + + [emu] = program.bloqade.python().hamiltonian() + + times = np.linspace(0, 4, 101) + state_iter = emu.evolve( + times=times, interaction_picture=interaction, atol=1e-10, rtol=1e-14 + ) + + h = emu.hamiltonian.tocsr(0) + print(h.toarray()) + e, v = np.linalg.eigh(h.toarray()) + + psi0 = emu.zero_state().data + + for state, time in zip(state_iter, times): + assert str(state) + expected_data = v @ (np.diag(np.exp(-1j * e * time)) @ (v.T.conj() @ psi0)) + data = state.data + + expected_average = np.vdot(expected_data, h.dot(expected_data)).real + expected_variance = ( + np.vdot(expected_data, (h @ h).dot(expected_data)) - expected_average**2 + ).real + + print(emu.hamiltonian._apply(state.data, time) - h.dot(expected_data)) + + average = emu.hamiltonian.average(state, time=time) + variance = emu.hamiltonian.variance(state, time=time) + average_2, variance_2 = emu.hamiltonian.average_and_variance(state, time=time) + overlap = np.vdot(data, expected_data) + assert isclose(overlap, 1.0, abs_tol=1e-7), f"failed data at time {time}" + assert isclose( + average, expected_average, rel_tol=1e-7, abs_tol=1e-7 + ), f"failed average at time {time}" + assert isclose( + average_2, expected_average, rel_tol=1e-7, abs_tol=1e-7 + ), f"failed average_2 at time {time}" + assert isclose( + variance, expected_variance, rel_tol=1e-7, abs_tol=1e-7 + ), f"failed variance at time {time}" + assert isclose( + variance_2, expected_variance, rel_tol=1e-7, abs_tol=1e-7 + ), f"failed variance_2 at time {time}" + + # assert False diff --git a/tests/test_run_callback.py b/tests/test_run_callback.py index 6ba93c639..d2dab4bfc 100644 --- a/tests/test_run_callback.py +++ b/tests/test_run_callback.py @@ -30,8 +30,8 @@ def test_run_callback(): callback, multiprocessing=True, num_workers=1 ) - np.testing.assert_equal(expected_result, result_single) - np.testing.assert_equal(expected_result, result_multi) + np.testing.assert_equal(expected_result.data, result_single.data) + np.testing.assert_equal(expected_result.data, result_multi.data) def callback_exception(*args): diff --git a/tests/test_space.py b/tests/test_space.py index 047ba481b..1a559de1d 100644 --- a/tests/test_space.py +++ b/tests/test_space.py @@ -696,14 +696,14 @@ def test_zero_state(): space = Space.create(register) print(space) - assert np.all(space.zero_state() == np.array([1, 0, 0, 0, 0, 0, 0, 0])) + assert np.all(space.zero_state().data == np.array([1, 0, 0, 0, 0, 0, 0, 0])) positions = [(0, 0), (0, 1), (1, 0)] register = Register(TwoLevelAtom, positions, 1) space = Space.create(register) print(space) - assert np.all(space.zero_state() == np.array([1, 0, 0, 0, 0])) + assert np.all(space.zero_state().data == np.array([1, 0, 0, 0, 0])) @patch("bloqade.emulate.ir.space.np.random.choice")