diff --git a/src/qutip_tensornetwork/core/data/tensor_train/tensor_train.py b/src/qutip_tensornetwork/core/data/tensor_train/tensor_train.py index e9a2132..f062b76 100644 --- a/src/qutip_tensornetwork/core/data/tensor_train/tensor_train.py +++ b/src/qutip_tensornetwork/core/data/tensor_train/tensor_train.py @@ -2,7 +2,9 @@ import tensornetwork as tn from itertools import chain -__all__ = ["FiniteTT", ] +__all__ = [ + "FiniteTT", +] class FiniteTT(Network): @@ -197,8 +199,8 @@ def _to_tt_format(self): left_edges += in_edges[i] left_edges += lbond - right_edges = out_edges[i + 1:] - right_edges += in_edges[i + 1:] + right_edges = out_edges[i + 1 :] + right_edges += in_edges[i + 1 :] # We flatten the right edges as it is a list of lists right_edges = list(chain(*right_edges)) @@ -235,6 +237,7 @@ def _to_tt_format(self): self._nodes = set(nodes) + def _check_shape(nodes): """Check that the nodes have the appropriate shape for the `from_node_list` method."""