Skip to content

Commit 6a7c381

Browse files
type Index view and drop (#1518)
* type Index view and drop * note mypy vs pyright discrepancy * old numpy * more precise comments, laxer arg types, return 1d np array Co-authored-by: Yi-Fan Wang <cmp0xff@users.noreply.github.com> * fixup * fixup * Update tests/indexes/test_indexes.py Co-authored-by: Yi-Fan Wang <cmp0xff@users.noreply.github.com> * fixup syntax * use np.ndarray in test * use np.ndarray in test * use np.ndarray in test * use np.ndarray in test --------- Co-authored-by: Yi-Fan Wang <cmp0xff@users.noreply.github.com>
1 parent b983c83 commit 6a7c381

File tree

4 files changed

+59
-4
lines changed

4 files changed

+59
-4
lines changed

pandas-stubs/_typing.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -950,6 +950,8 @@ np_1darray_dt: TypeAlias = np_1darray[np.datetime64]
950950
np_1darray_td: TypeAlias = np_1darray[np.timedelta64]
951951
np_2darray: TypeAlias = np.ndarray[tuple[int, int], np.dtype[GenericT]]
952952

953+
NDArrayT = TypeVar("NDArrayT", bound=np.ndarray)
954+
953955
DtypeNp = TypeVar("DtypeNp", bound=np.dtype[np.generic])
954956
KeysArgType: TypeAlias = Any
955957
ListLikeT = TypeVar("ListLikeT", bound=ListLike)

pandas-stubs/core/indexes/base.pyi

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ from pandas._typing import (
102102
Level,
103103
MaskType,
104104
NaPosition,
105+
NDArrayT,
105106
NumpyFloatNot16DtypeArg,
107+
NumpyNotTimeDtypeArg,
108+
NumpyTimedeltaDtypeArg,
109+
NumpyTimestampDtypeArg,
106110
PandasAstypeFloatDtypeArg,
107111
PandasFloatDtypeArg,
108112
PyArrowFloatDtypeArg,
@@ -374,7 +378,15 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
374378
def dtype(self) -> DtypeObj: ...
375379
@final
376380
def ravel(self, order: _str = "C") -> Self: ...
377-
def view(self, cls=...): ...
381+
@overload
382+
def view(self, cls: None = None) -> Self: ...
383+
@overload
384+
def view(self, cls: type[NDArrayT]) -> NDArrayT: ...
385+
@overload
386+
def view(
387+
self,
388+
cls: NumpyNotTimeDtypeArg | NumpyTimedeltaDtypeArg | NumpyTimestampDtypeArg,
389+
) -> np_1darray: ...
378390
@overload
379391
def astype(
380392
self,
@@ -596,7 +608,11 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
596608
def insert(self, loc: int, item: S1) -> Self: ...
597609
@overload
598610
def insert(self, loc: int, item: object) -> Index: ...
599-
def drop(self, labels, errors: IgnoreRaise = "raise") -> Self: ...
611+
def drop(
612+
self,
613+
labels: IndexOpsMixin | np_ndarray | Iterable[Hashable],
614+
errors: IgnoreRaise = "raise",
615+
) -> Self: ...
600616
@property
601617
def shape(self) -> tuple[int, ...]: ...
602618
# Extra methods from old stubs

pyproject.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,12 @@ ignore = [
237237
# TODO: remove when _libs is fully typed
238238
"ANN001", "ANN201", "ANN204", "ANN206",
239239
]
240-
"*base.pyi" = [
241-
# TODO: remove when base.pyi's are fully typed
240+
"*core/base.pyi" = [
241+
# TODO: remove when core/base.pyi is fully typed
242+
"ANN001", "ANN201", "ANN204", "ANN206",
243+
]
244+
"*excel/_base.pyi" = [
245+
# TODO: remove when excel/_base.pyi is fully typed
242246
"ANN001", "ANN201", "ANN204", "ANN206",
243247
]
244248
"scripts/*" = [

tests/indexes/test_indexes.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections.abc import Hashable
44
import datetime as dt
5+
import sys
56
from typing import (
67
Any,
78
cast,
@@ -1667,3 +1668,35 @@ def test_index_slice_locs() -> None:
16671668
start, end = idx.slice_locs(0, 1)
16681669
check(assert_type(start, np.intp | int), np.integer)
16691670
check(assert_type(end, np.intp | int), int)
1671+
1672+
1673+
def test_index_view() -> None:
1674+
ind = pd.Index([1, 2])
1675+
check(assert_type(ind.view("int64"), np_1darray), np_1darray)
1676+
check(assert_type(ind.view(), "pd.Index[int]"), pd.Index)
1677+
if sys.version_info >= (3, 11):
1678+
# mypy and pyright differ here in what they report:
1679+
# - mypy: ndarray[Any, Any]"
1680+
# - pyright: ndarray[tuple[Any, ...], dtype[Any]]
1681+
check(assert_type(ind.view(np.ndarray), np.ndarray), np.ndarray) # type: ignore[assert-type]
1682+
else:
1683+
check(assert_type(ind.view(np.ndarray), np.ndarray), np.ndarray)
1684+
1685+
class MyArray(np.ndarray): ...
1686+
1687+
check(assert_type(ind.view(MyArray), MyArray), MyArray)
1688+
1689+
1690+
def test_index_drop() -> None:
1691+
ind = pd.Index([1, 2, 3])
1692+
check(assert_type(ind.drop([1, 2]), "pd.Index[int]"), pd.Index, np.integer)
1693+
check(
1694+
assert_type(ind.drop(pd.Index([1, 2])), "pd.Index[int]"), pd.Index, np.integer
1695+
)
1696+
check(
1697+
assert_type(ind.drop(pd.Series([1, 2])), "pd.Index[int]"), pd.Index, np.integer
1698+
)
1699+
check(
1700+
assert_type(ind.drop(np.array([1, 2])), "pd.Index[int]"), pd.Index, np.integer
1701+
)
1702+
check(assert_type(ind.drop(iter([1, 2])), "pd.Index[int]"), pd.Index, np.integer)

0 commit comments

Comments
 (0)