Skip to content

Commit afff91d

Browse files
committed
fix(cudf): more skrub upgrade fixes
1 parent 9a3a886 commit afff91d

File tree

3 files changed: 73 additions and 10 deletions.

graphistry/feature_utils.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -960,10 +960,20 @@ def process_dirty_dataframes(
960960

961961
logger.info(":: Encoding DataFrame might take a few minutes ------")
962962

963+
if 'cudf' in str(getmodule(ndf)):
964+
import cudf
965+
assert isinstance(ndf, cudf.DataFrame)
966+
logger.debug('Coercing cudf to pandas for skrub, with feature_engine=%s', feature_engine)
967+
ndf_passthrough = ndf.to_pandas()
968+
coercing_to_pandas = True
969+
else:
970+
ndf_passthrough = ndf
971+
coercing_to_pandas = False
972+
963973
try:
964-
X_enc = data_encoder.fit_transform(ndf, y)
974+
X_enc = data_encoder.fit_transform(ndf_passthrough, y)
965975
except TypeError:
966-
nndf = ndf.copy()
976+
nndf = ndf_passthrough.copy()
967977
object_columns = nndf.select_dtypes(include=['object']).columns
968978
nndf[object_columns] = nndf[object_columns].astype(str)
969979
X_enc = data_encoder.fit_transform(nndf, y)
@@ -990,9 +1000,14 @@ def process_dirty_dataframes(
9901000
data_encoder.get_feature_names_out = callThrough(features_transformed)
9911001

9921002
X_enc = pd.DataFrame(
993-
X_enc, columns=features_transformed, index=ndf.index
1003+
X_enc, columns=features_transformed, index=ndf_passthrough.index
9941004
)
9951005
X_enc = X_enc.fillna(0.0)
1006+
1007+
if coercing_to_pandas:
1008+
import cudf
1009+
X_enc = cudf.DataFrame.from_pandas(X_enc)
1010+
9961011
elif not all_numeric and (not has_skrub or feature_engine in ["pandas", "none"]):
9971012
numeric_ndf = ndf.select_dtypes(include=[np.number]) # type: ignore
9981013
logger.warning("-*-*- DataFrame is not numeric and no skrub, dropping non-numeric")

graphistry/tests/layout/ring/test_ring_categorical.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ def test_ring_cudf(self):
131131
rs = (g._nodes['x'] * g._nodes['x'] + g._nodes['y'] * g._nodes['y']).apply(np.sqrt)
132132
assert rs.min() == 500
133133
assert rs.max() == 800
134-
assert len(g._complex_encodings and g._complex_encodings['node_encodings']['default']['pointAxisEncoding']['rows']) == 5
134+
assert len(g._complex_encodings and g._complex_encodings['node_encodings']['default']['pointAxisEncoding']['rows']) == 4
135135
for i, row in enumerate(g._complex_encodings['node_encodings']['default']['pointAxisEncoding']['rows']):
136136
assert row['r'] == 500 + 100 * i
137-
assert row['label'] == str(2 + 2 * i)
137+
assert row['label'] == ['a', 'bb', 'cc', 'dd'][i]

graphistry/text_utils.py

Lines changed: 53 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
from typing import TYPE_CHECKING
2+
from inspect import getmodule
23
from logging import getLogger
34
import pandas as pd
45

@@ -40,6 +41,8 @@ def build_index(self, angular=False, n_trees=None):
4041
self.assert_fitted()
4142
self.assert_features_line_up_with_nodes()
4243
X = self._get_feature("nodes")
44+
if 'cudf' in str(getmodule(X)):
45+
X = X.to_pandas()
4346
self.search_index = FaissVectorSearch(
4447
X.values
4548
) # self._build_search_index(X, angular, n_trees, faiss=False)
@@ -48,6 +51,10 @@ def _query_from_dataframe(self, qdf: pd.DataFrame, top_n: int, thresh: float):
4851
# Use the loaded featurizers to transform the dataframe
4952
vect, _ = self.transform(qdf, None, kind="nodes", return_graph=False)
5053

54+
nodes = self._nodes
55+
if 'cudf' in str(getmodule(nodes)):
56+
nodes = nodes.to_pandas()
57+
5158
results = self.search_index.search_df(vect, self._nodes, top_n)
5259
results = results.query(f"{DISTANCE} < {thresh}")
5360

@@ -210,15 +217,40 @@ def search_graph(
210217
# print('shape of edges', edf.shape)
211218
rdf = df = res._nodes
212219
# print('shape of nodes', rdf.shape)
220+
221+
if 'cudf' in str(getmodule(edges)):
222+
import cudf
223+
224+
if not isinstance(rdf, cudf.DataFrame):
225+
rdf = cudf.from_pandas(rdf)
226+
df = rdf
227+
228+
concat = cudf.concat
229+
cudf_coercion = True
230+
else:
231+
concat = pd.concat
232+
cudf_coercion = False
233+
213234
node = res._node
214235
indices = rdf[node]
236+
if cudf_coercion:
237+
import cudf
238+
if not isinstance(indices, cudf.Series):
239+
indices = cudf.Series.from_pandas(indices)
215240
src = res._source
216241
dst = res._destination
217242
if query != "":
218243
# run a real query, else return entire graph
219244
rdf, _ = res.search(query, thresh=thresh, fuzzy=True, top_n=top_n)
220245
if not rdf.empty:
246+
if cudf_coercion:
247+
import cudf
248+
#if not isinstance(indices, cudf.Series):
249+
# indices = cudf.Series.from_pandas(indices)
250+
if not isinstance(rdf, cudf.DataFrame):
251+
rdf = cudf.from_pandas(rdf)
221252
indices = rdf[node]
253+
222254
# now get edges from indices
223255
if broader: # this will make a broader graph, finding NN in src OR dst
224256
edges = edf[(edf[src].isin(indices)) | (edf[dst].isin(indices))]
@@ -236,19 +268,35 @@ def search_graph(
236268
except: # for explicit edges
237269
pass
238270

239-
found_indices = pd.concat([edges[src], edges[dst], indices], axis=0).unique()
271+
#logger.info('type edges=%s, indices=%s', type(edges), type(indices))
272+
#raise ValueError(f'stop here: {type(edges)}, {type(indices)}')
273+
274+
found_indices = concat([edges[src], edges[dst], indices], axis=0).unique()
240275
emb = None
276+
node_feats = res._node_features
277+
if cudf_coercion:
278+
import cudf
279+
if not isinstance(node_feats, cudf.DataFrame):
280+
node_feats = cudf.from_pandas(node_feats)
281+
282+
node_emb = res._node_embedding
283+
if cudf_coercion and res._umap is not None:
284+
import cudf
285+
node_emb = res._node_embedding
286+
if not isinstance(node_emb, cudf.DataFrame):
287+
node_emb = cudf.from_pandas(node_emb)
288+
241289
try:
242290
tdf = rdf.iloc[found_indices]
243-
feats = res._node_features.iloc[found_indices] # type: ignore
291+
feats = node_feats.iloc[found_indices] # type: ignore
244292
if res._umap is not None:
245-
emb = res._node_embedding.iloc[found_indices] # type: ignore
293+
emb = node_emb.iloc[found_indices] # type: ignore
246294
except Exception: # for explicit relabeled nodes
247295
#logger.exception(e)
248296
tdf = rdf[df[node].isin(found_indices)]
249-
feats = res._node_features.loc[tdf.index] # type: ignore
297+
feats = node_feats.loc[tdf.index] # type: ignore
250298
if res._umap is not None:
251-
emb = res._node_embedding[df[node].isin(found_indices)] # type: ignore
299+
emb = node_emb[df[node].isin(found_indices)] # type: ignore
252300
logger.info(f" - Returning edge dataframe of size {edges.shape[0]}")
253301
# get all the unique nodes
254302
logger.info(

0 commit comments

Comments
 (0)