From 8c380c8105a1d0c82d8c679c56091ed097b9c9a9 Mon Sep 17 00:00:00 2001 From: Franklin <41602287+fcogidi@users.noreply.github.com> Date: Mon, 12 Feb 2024 11:14:42 -0500 Subject: [PATCH] fix doctest error --- cyclops/utils/index.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/cyclops/utils/index.py b/cyclops/utils/index.py index 2e3941a02..3b281ec7a 100644 --- a/cyclops/utils/index.py +++ b/cyclops/utils/index.py @@ -3,6 +3,7 @@ from typing import Any, List, Optional, Sequence, Tuple, Union import numpy as np +import numpy.typing as npt def index_axis(ind: int, axis: int, shape: Tuple[int, ...]) -> Tuple[Any, ...]: @@ -33,9 +34,9 @@ def index_axis(ind: int, axis: int, shape: Tuple[int, ...]) -> Tuple[Any, ...]: def take_indices( - data: np.typing.NDArray[Any], - indexes: Sequence[Optional[Union[Sequence[int], np.typing.NDArray[Any]]]], -) -> np.typing.NDArray[Any]: + data: npt.NDArray[Any], + indexes: Sequence[Optional[Union[Sequence[int], npt.NDArray[Any]]]], +) -> npt.NDArray[Any]: """Index array by specifying the indices to take on each axis. Parameters @@ -69,10 +70,10 @@ def take_indices( def take_indices_over_axis( - data: np.typing.NDArray[Any], + data: npt.NDArray[Any], axis: int, - index: Union[np.typing.NDArray[Any], Sequence[int]], -) -> np.typing.NDArray[Any]: + index: Union[npt.NDArray[Any], Sequence[int]], +) -> npt.NDArray[Any]: """Take indices along an axis. Parameters