Skip to content

Commit

Permalink
fix dep warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
king-p3nguin committed Apr 9, 2024
1 parent cff2344 commit 5c0c768
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 20 deletions.
13 changes: 7 additions & 6 deletions tensorcircuit/backends/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@

# pylint: disable=invalid-name

from functools import partial
import logging
import warnings
from functools import partial
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import numpy as np
from scipy.sparse import coo_matrix
import tensornetwork
from scipy.sparse import coo_matrix
from tensornetwork.backends.jax import jax_backend

from .abstract_backend import ExtendedBackend

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -196,8 +197,8 @@ def __init__(self) -> None:
"Jax not installed, please switch to a different "
"backend or install Jax."
)
from jax.experimental import sparse
import jax.scipy
from jax.experimental import sparse

try:
import optax
Expand Down Expand Up @@ -419,7 +420,7 @@ def searchsorted(self, a: Tensor, v: Tensor, side: str = "left") -> Tensor:
return jnp.searchsorted(a, v, side)

def tree_map(self, f: Callable[..., Any], *pytrees: Any) -> Any:
return libjax.tree_map(f, *pytrees)
return libjax.tree_util.tree_map(f, *pytrees)

def tree_flatten(self, pytree: Any) -> Tuple[Any, Any]:
return libjax.tree_util.tree_flatten(pytree) # type: ignore
Expand Down Expand Up @@ -630,7 +631,7 @@ def is_sparse(self, a: Tensor) -> bool:
return isinstance(a, sparse.BCOO) # type: ignore

def device(self, a: Tensor) -> str:
dev = a.device()
dev = a.devices()
return self._dev2str(dev)

def device_move(self, a: Tensor, dev: Any) -> Tensor:
Expand Down Expand Up @@ -757,7 +758,7 @@ def wrapper(
gs = list(gs)
for i, (j, g) in enumerate(zip(argnums_list, gs)):
if j not in vectorized_argnums: # type: ignore
gs[i] = libjax.tree_map(partial(jnp.sum, axis=0), g)
gs[i] = libjax.tree_util.tree_map(partial(jnp.sum, axis=0), g)
if isinstance(argnums, int):
gs = gs[0]
else:
Expand Down
13 changes: 7 additions & 6 deletions tensorcircuit/compiler/qiskit_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
compiler interface via qiskit
"""

from typing import Any, Dict, Optional
import re
from typing import Any, Dict, Optional

from ..abstractcircuit import AbstractCircuit
from ..circuit import Circuit
Expand Down Expand Up @@ -71,7 +71,7 @@ def _get_positional_logical_mapping_from_qiskit(qc: Any) -> Dict[int, int]:
positional_logical_mapping = {}
for inst in qc.data:
if inst[0].name == "measure":
positional_logical_mapping[i] = inst[1][0].index
positional_logical_mapping[i] = qc.find_bit(inst[1][0]).index
i += 1

return positional_logical_mapping
Expand All @@ -95,16 +95,17 @@ def _get_logical_physical_mapping_from_qiskit(
for inst in qc_after.data:
if inst[0].name == "measure":
if qc_before is None:
logical_q = inst[2][0].index
logical_q = qc_after.find_bit(inst[2][0]).index
else:
for instb in qc_before.data:
if (
instb[0].name == "measure"
and instb[2][0].index == inst[2][0].index
and qc_before.find_bit(instb[2][0]).index
== qc_after.find_bit(inst[2][0]).index
):
logical_q = instb[1][0].index
logical_q = qc_before.find_bit(instb[1][0]).index
break
logical_physical_mapping[logical_q] = inst[1][0].index
logical_physical_mapping[logical_q] = qc_after.find_bit(inst[1][0]).index
return logical_physical_mapping


Expand Down
12 changes: 5 additions & 7 deletions tensorcircuit/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@
import sympy
from qiskit import QuantumCircuit
from qiskit.circuit import Parameter, ParameterExpression
from qiskit.circuit.library import XXPlusYYGate
from qiskit.circuit.exceptions import CircuitError
from qiskit.circuit.library import HamiltonianGate, UnitaryGate, XXPlusYYGate
from qiskit.circuit.parametervector import ParameterVectorElement
from qiskit.circuit.quantumcircuitdata import CircuitInstruction
from qiskit.extensions import UnitaryGate
from qiskit.extensions.exceptions import ExtensionError
except ImportError:
logger.warning(
"Please first ``pip install -U qiskit`` to enable related functionality in translation module"
Expand Down Expand Up @@ -311,9 +310,8 @@ def qir2qiskit(
# Error can be presented if theta is actually complex in this procedure.
exp_op = qi.Operator(unitary)
index_reversed = [x for x in index[::-1]]
qiskit_circ.hamiltonian(
exp_op, time=theta, qubits=index_reversed, label=qis_name
)
gate = HamiltonianGate(data=exp_op, time=theta, label=qis_name)
qiskit_circ.append(gate, index_reversed)
elif gate_name == "multicontrol":
unitary = backend.numpy(backend.convert_to_tensor(parameters["unitary"]))
ctrl_str = "".join(map(str, parameters["ctrl"]))[::-1]
Expand Down Expand Up @@ -344,7 +342,7 @@ def qir2qiskit(
qop = qi.Operator(gatem)
try:
qiskit_circ.unitary(qop, index[::-1], label=qis_name)
except (ExtensionError, ValueError) as _:
except (CircuitError, ValueError) as _:
logger.warning(
"omit non unitary gate in tensorcircuit when transforming to qiskit: %s"
% gate_name
Expand Down
4 changes: 3 additions & 1 deletion tests/test_circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,7 @@ def test_qiskit2tc():
try:
import qiskit.quantum_info as qi
from qiskit import QuantumCircuit
from qiskit.circuit.library import HamiltonianGate
from qiskit.circuit.library.standard_gates import MCXGate, SwapGate

from tensorcircuit.translation import perm_matrix
Expand All @@ -1090,7 +1091,8 @@ def test_qiskit2tc():
zz = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
exp_op = qi.Operator(zz)
for i in range(n):
qisc.hamiltonian(exp_op, time=np.random.uniform(), qubits=[i, (i + 1) % n])
gate = HamiltonianGate(exp_op, time=np.random.uniform())
qisc.append(gate, [i, (i + 1) % n])
qisc.fredkin(1, 2, 3)
qisc.cswap(1, 2, 3)
qisc.swap(0, 1)
Expand Down

0 comments on commit 5c0c768

Please sign in to comment.