From 5c0c768ebaeb6ee82f82581cd79d37231b1709a3 Mon Sep 17 00:00:00 2001 From: king-p3nguin Date: Tue, 9 Apr 2024 09:01:33 +0900 Subject: [PATCH] fix dep warnings --- tensorcircuit/backends/jax_backend.py | 13 +++++++------ tensorcircuit/compiler/qiskit_compiler.py | 13 +++++++------ tensorcircuit/translation.py | 12 +++++------- tests/test_circuit.py | 4 +++- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/tensorcircuit/backends/jax_backend.py b/tensorcircuit/backends/jax_backend.py index 581d8be7..f19c23ce 100644 --- a/tensorcircuit/backends/jax_backend.py +++ b/tensorcircuit/backends/jax_backend.py @@ -4,15 +4,16 @@ # pylint: disable=invalid-name -from functools import partial import logging import warnings +from functools import partial from typing import Any, Callable, Optional, Sequence, Tuple, Union import numpy as np -from scipy.sparse import coo_matrix import tensornetwork +from scipy.sparse import coo_matrix from tensornetwork.backends.jax import jax_backend + from .abstract_backend import ExtendedBackend logger = logging.getLogger(__name__) @@ -196,8 +197,8 @@ def __init__(self) -> None: "Jax not installed, please switch to a different " "backend or install Jax." ) - from jax.experimental import sparse import jax.scipy + from jax.experimental import sparse try: import optax @@ -419,7 +420,7 @@ def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor: return jnp.searchsorted(a, v, side) def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any: - return libjax.tree_map(f, *pytrees) + return libjax.tree_util.tree_map(f, *pytrees) def tree_flatten(self, pytree: Any) -> Tuple[Any, Any]: return libjax.tree_util.tree_flatten(pytree) # type: ignore @@ -630,7 +631,7 @@ def is_sparse(self, a: Tensor) -> bool: return isinstance(a, sparse.BCOO) # type: ignore def device(self, a: Tensor) -> str: - dev = a.device() + dev = a.devices() return self._dev2str(dev) def device_move(self, a: Tensor, dev: Any) -> Tensor: @@ -757,7 +758,7 @@ def wrapper( gs = list(gs) for i, (j, g) in enumerate(zip(argnums_list, gs)): if j not in vectorized_argnums: # type: ignore - gs[i] = libjax.tree_map(partial(jnp.sum, axis=0), g) + gs[i] = libjax.tree_util.tree_map(partial(jnp.sum, axis=0), g) if isinstance(argnums, int): gs = gs[0] else: diff --git a/tensorcircuit/compiler/qiskit_compiler.py b/tensorcircuit/compiler/qiskit_compiler.py index e54dd1a4..1f946359 100644 --- a/tensorcircuit/compiler/qiskit_compiler.py +++ b/tensorcircuit/compiler/qiskit_compiler.py @@ -2,8 +2,8 @@ compiler interface via qiskit """ -from typing import Any, Dict, Optional import re +from typing import Any, Dict, Optional from ..abstractcircuit import AbstractCircuit from ..circuit import Circuit @@ -71,7 +71,7 @@ def _get_positional_logical_mapping_from_qiskit(qc: Any) -> Dict[int, int]: positional_logical_mapping = {} for inst in qc.data: if inst[0].name == "measure": - positional_logical_mapping[i] = inst[1][0].index + positional_logical_mapping[i] = qc.find_bit(inst[1][0]).index i += 1 return positional_logical_mapping @@ -95,16 +95,17 @@ def _get_logical_physical_mapping_from_qiskit( for inst in qc_after.data: if inst[0].name == "measure": if qc_before is None: - logical_q = inst[2][0].index + logical_q = qc_after.find_bit(inst[2][0]).index else: for instb in qc_before.data: if ( instb[0].name == "measure" - and instb[2][0].index == inst[2][0].index + and qc_before.find_bit(instb[2][0]).index + == qc_after.find_bit(inst[2][0]).index ): - logical_q = instb[1][0].index + logical_q = qc_before.find_bit(instb[1][0]).index break - logical_physical_mapping[logical_q] = inst[1][0].index + logical_physical_mapping[logical_q] = qc_after.find_bit(inst[1][0]).index return logical_physical_mapping diff --git a/tensorcircuit/translation.py b/tensorcircuit/translation.py index 992eed39..c6ca1e0e 100644 --- a/tensorcircuit/translation.py +++ b/tensorcircuit/translation.py @@ -17,11 +17,10 @@ import sympy from qiskit import QuantumCircuit from qiskit.circuit import Parameter, ParameterExpression - from qiskit.circuit.library import XXPlusYYGate + from qiskit.circuit.exceptions import CircuitError + from qiskit.circuit.library import HamiltonianGate, UnitaryGate, XXPlusYYGate from qiskit.circuit.parametervector import ParameterVectorElement from qiskit.circuit.quantumcircuitdata import CircuitInstruction - from qiskit.extensions import UnitaryGate - from qiskit.extensions.exceptions import ExtensionError except ImportError: logger.warning( "Please first ``pip install -U qiskit`` to enable related functionality in translation module" @@ -311,9 +310,8 @@ def qir2qiskit( # Error can be presented if theta is actually complex in this procedure. exp_op = qi.Operator(unitary) index_reversed = [x for x in index[::-1]] - qiskit_circ.hamiltonian( - exp_op, time=theta, qubits=index_reversed, label=qis_name - ) + gate = HamiltonianGate(data=exp_op, time=theta, label=qis_name) + qiskit_circ.append(gate, index_reversed) elif gate_name == "multicontrol": unitary = backend.numpy(backend.convert_to_tensor(parameters["unitary"])) ctrl_str = "".join(map(str, parameters["ctrl"]))[::-1] @@ -344,7 +342,7 @@ def qir2qiskit( qop = qi.Operator(gatem) try: qiskit_circ.unitary(qop, index[::-1], label=qis_name) - except (ExtensionError, ValueError) as _: + except (CircuitError, ValueError) as _: logger.warning( "omit non unitary gate in tensorcircuit when transforming to qiskit: %s" % gate_name diff --git a/tests/test_circuit.py b/tests/test_circuit.py index 16bc4844..81f20762 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -1078,6 +1078,7 @@ def test_qiskit2tc(): try: import qiskit.quantum_info as qi from qiskit import QuantumCircuit + from qiskit.circuit.library import HamiltonianGate from qiskit.circuit.library.standard_gates import MCXGate, SwapGate from tensorcircuit.translation import perm_matrix @@ -1090,7 +1091,8 @@ def test_qiskit2tc(): zz = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) exp_op = qi.Operator(zz) for i in range(n): - qisc.hamiltonian(exp_op, time=np.random.uniform(), qubits=[i, (i + 1) % n]) + gate = HamiltonianGate(exp_op, time=np.random.uniform()) + qisc.append(gate, [i, (i + 1) % n]) qisc.fredkin(1, 2, 3) qisc.cswap(1, 2, 3) qisc.swap(0, 1)