Skip to content

Commit f58dcfb

Browse files
committed
fix(predicates): avoid pandas 2.0.3 mypy bool error with cudf mask helper
1 parent 6f5f333 commit f58dcfb

File tree

1 file changed

+11
-7
lines changed
  • graphistry/compute/predicates

1 file changed

+11
-7
lines changed

graphistry/compute/predicates/str.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1-
from typing import Optional, Union
1+
from typing import Any, Optional, Union
22

33
import pandas as pd
44

5+
6+
def _cudf_mask_none(result: Any, mask: Any) -> Any:
7+
result_pd = result.to_pandas().astype('object')
8+
result_pd.iloc[mask] = None
9+
return result_pd
10+
511
from .ASTPredicate import ASTPredicate
612
from graphistry.compute.typing import SeriesT
713

@@ -178,9 +184,8 @@ def __call__(self, s: SeriesT) -> SeriesT:
178184
has_na: bool = bool(s.isna().any())
179185
if has_na:
180186
# Convert to object dtype and apply mask to preserve None values
181-
na_mask: pd.Series = s.to_pandas().isna()
182-
result_pd = result.to_pandas().astype('object').where(~na_mask, None)
183-
result = cudf.from_pandas(result_pd)
187+
na_mask_arr = s.to_pandas().isna().to_numpy()
188+
result = cudf.from_pandas(_cudf_mask_none(result, na_mask_arr))
184189
else:
185190
if not self.case:
186191
s_modified = s.str.lower()
@@ -346,9 +351,8 @@ def __call__(self, s: SeriesT) -> SeriesT:
346351
has_na: bool = bool(s.isna().any())
347352
if has_na:
348353
# Convert to object dtype and apply mask to preserve None values
349-
na_mask: pd.Series = s.to_pandas().isna()
350-
result_pd = result.to_pandas().astype('object').where(~na_mask, None)
351-
result = cudf.from_pandas(result_pd)
354+
na_mask_arr = s.to_pandas().isna().to_numpy()
355+
result = cudf.from_pandas(_cudf_mask_none(result, na_mask_arr))
352356
else:
353357
if not self.case:
354358
s_modified = s.str.lower()

0 commit comments

Comments
 (0)