1
1
from typing import TYPE_CHECKING
2
+ from inspect import getmodule
2
3
from logging import getLogger
3
4
import pandas as pd
4
5
@@ -40,6 +41,8 @@ def build_index(self, angular=False, n_trees=None):
40
41
self .assert_fitted ()
41
42
self .assert_features_line_up_with_nodes ()
42
43
X = self ._get_feature ("nodes" )
44
+ if 'cudf' in str (getmodule (X )):
45
+ X = X .to_pandas ()
43
46
self .search_index = FaissVectorSearch (
44
47
X .values
45
48
) # self._build_search_index(X, angular, n_trees, faiss=False)
@@ -48,6 +51,10 @@ def _query_from_dataframe(self, qdf: pd.DataFrame, top_n: int, thresh: float):
48
51
# Use the loaded featurizers to transform the dataframe
49
52
vect , _ = self .transform (qdf , None , kind = "nodes" , return_graph = False )
50
53
54
+ nodes = self ._nodes
55
+ if 'cudf' in str (getmodule (nodes )):
56
+ nodes = nodes .to_pandas ()
57
+
51
58
results = self .search_index .search_df (vect , self ._nodes , top_n )
52
59
results = results .query (f"{ DISTANCE } < { thresh } " )
53
60
@@ -210,15 +217,40 @@ def search_graph(
210
217
# print('shape of edges', edf.shape)
211
218
rdf = df = res ._nodes
212
219
# print('shape of nodes', rdf.shape)
220
+
221
+ if 'cudf' in str (getmodule (edges )):
222
+ import cudf
223
+
224
+ if not isinstance (rdf , cudf .DataFrame ):
225
+ rdf = cudf .from_pandas (rdf )
226
+ df = rdf
227
+
228
+ concat = cudf .concat
229
+ cudf_coercion = True
230
+ else :
231
+ concat = pd .concat
232
+ cudf_coercion = False
233
+
213
234
node = res ._node
214
235
indices = rdf [node ]
236
+ if cudf_coercion :
237
+ import cudf
238
+ if not isinstance (indices , cudf .Series ):
239
+ indices = cudf .Series .from_pandas (indices )
215
240
src = res ._source
216
241
dst = res ._destination
217
242
if query != "" :
218
243
# run a real query, else return entire graph
219
244
rdf , _ = res .search (query , thresh = thresh , fuzzy = True , top_n = top_n )
220
245
if not rdf .empty :
246
+ if cudf_coercion :
247
+ import cudf
248
+ #if not isinstance(indices, cudf.Series):
249
+ # indices = cudf.Series.from_pandas(indices)
250
+ if not isinstance (rdf , cudf .DataFrame ):
251
+ rdf = cudf .from_pandas (rdf )
221
252
indices = rdf [node ]
253
+
222
254
# now get edges from indices
223
255
if broader : # this will make a broader graph, finding NN in src OR dst
224
256
edges = edf [(edf [src ].isin (indices )) | (edf [dst ].isin (indices ))]
@@ -236,19 +268,35 @@ def search_graph(
236
268
except : # for explicit edges
237
269
pass
238
270
239
- found_indices = pd .concat ([edges [src ], edges [dst ], indices ], axis = 0 ).unique ()
271
+ #logger.info('type edges=%s, indices=%s', type(edges), type(indices))
272
+ #raise ValueError(f'stop here: {type(edges)}, {type(indices)}')
273
+
274
+ found_indices = concat ([edges [src ], edges [dst ], indices ], axis = 0 ).unique ()
240
275
emb = None
276
+ node_feats = res ._node_features
277
+ if cudf_coercion :
278
+ import cudf
279
+ if not isinstance (node_feats , cudf .DataFrame ):
280
+ node_feats = cudf .from_pandas (node_feats )
281
+
282
+ node_emb = res ._node_embedding
283
+ if cudf_coercion and res ._umap is not None :
284
+ import cudf
285
+ node_emb = res ._node_embedding
286
+ if not isinstance (node_emb , cudf .DataFrame ):
287
+ node_emb = cudf .from_pandas (node_emb )
288
+
241
289
try :
242
290
tdf = rdf .iloc [found_indices ]
243
- feats = res . _node_features .iloc [found_indices ] # type: ignore
291
+ feats = node_feats .iloc [found_indices ] # type: ignore
244
292
if res ._umap is not None :
245
- emb = res . _node_embedding .iloc [found_indices ] # type: ignore
293
+ emb = node_emb .iloc [found_indices ] # type: ignore
246
294
except Exception : # for explicit relabeled nodes
247
295
#logger.exception(e)
248
296
tdf = rdf [df [node ].isin (found_indices )]
249
- feats = res . _node_features .loc [tdf .index ] # type: ignore
297
+ feats = node_feats .loc [tdf .index ] # type: ignore
250
298
if res ._umap is not None :
251
- emb = res . _node_embedding [df [node ].isin (found_indices )] # type: ignore
299
+ emb = node_emb [df [node ].isin (found_indices )] # type: ignore
252
300
logger .info (f" - Returning edge dataframe of size { edges .shape [0 ]} " )
253
301
# get all the unique nodes
254
302
logger .info (
0 commit comments