@@ -281,22 +281,54 @@ def _add_embeddings(row, embeddings, info_msgs):
281
281
return row
282
282
283
283
284
- def _get_text_content (row ):
284
+ def _get_pandas_text_content (row ):
285
285
"""
286
286
A pandas UDF used to select extracted text content to be used to create embeddings.
287
287
"""
288
288
289
289
return row ["content" ]
290
290
291
291
292
- def _get_table_content (row ):
292
+ def _get_pandas_table_content (row ):
293
293
"""
294
294
A pandas UDF used to select extracted table/chart content to be used to create embeddings.
295
295
"""
296
296
297
297
return row ["table_metadata" ]["table_content" ]
298
298
299
299
300
+ def _get_pandas_image_content (row ):
301
+ """
302
+ A pandas UDF used to select extracted image captions to be used to create embeddings.
303
+ """
304
+
305
+ return row ["image_metadata" ]["caption" ]
306
+
307
+
308
def _get_cudf_text_content(df: cudf.Series):
    """
    A cuDF UDF used to select extracted text content to be used to create embeddings.

    Parameters
    ----------
    df : cudf.Series
        The struct-typed "metadata" column. Despite the parameter name, this is a Series
        (the caller passes `mdf["metadata"]`); the `.struct` accessor is used on it directly.

    Returns
    -------
    The "content" field of the struct column.
    """

    return df.struct.field("content")
314
+
315
+
316
def _get_cudf_table_content(df: cudf.Series):
    """
    A cuDF UDF used to select extracted table/chart content to be used to create embeddings.

    Parameters
    ----------
    df : cudf.Series
        The struct-typed "metadata" column. Despite the parameter name, this is a Series
        (the caller passes `mdf["metadata"]`); the `.struct` accessor is used on it directly.

    Returns
    -------
    The "table_content" field nested inside the "table_metadata" struct field.
    """

    return df.struct.field("table_metadata").struct.field("table_content")
322
+
323
+
324
def _get_cudf_image_content(df: cudf.Series):
    """
    A cuDF UDF used to select extracted image captions to be used to create embeddings.

    Parameters
    ----------
    df : cudf.Series
        The struct-typed "metadata" column. Despite the parameter name, this is a Series
        (the caller passes `mdf["metadata"]`); the `.struct` accessor is used on it directly.

    Returns
    -------
    The "caption" field nested inside the "image_metadata" struct field.
    """

    return df.struct.field("image_metadata").struct.field("caption")
330
+
331
+
300
332
def _batch_generator (iterable : Iterable , batch_size = 10 ):
301
333
"""
302
334
A generator to yield batches of size `batch_size` from an iterable.
@@ -349,7 +381,6 @@ def _generate_batches(prompts: List[str], batch_size: int = 100):
349
381
350
382
def _generate_embeddings(
    ctrl_msg: ControlMessage,
    event_loop: asyncio.SelectorEventLoop,
    batch_size: int,
    api_key: str,
    embedding_nim_endpoint,
    embedding_model,
    encoding_format,
    input_type,
    truncate,
    filter_errors: bool,
):
    """
    A function to generate text embeddings for supported content types (TEXT, STRUCTURED, IMAGE).

    This function dynamically selects the appropriate metadata field based on content type and
    calculates embeddings using the NIM embedding service. AUDIO and VIDEO types are stubbed and skipped.

    Parameters
    ----------
    ctrl_msg : ControlMessage
        The incoming message whose payload dataframe holds the extractions to embed.
    event_loop : asyncio.SelectorEventLoop
        Event loop forwarded to `_async_runner`.
    batch_size : int
        Number of prompts per batch passed to `_generate_batches`.
    api_key : str
        Forwarded unchanged to `_async_runner`.
    embedding_nim_endpoint, embedding_model, encoding_format, input_type, truncate
        Forwarded unchanged, in this order, to `_async_runner`.
    filter_errors : bool
        Forwarded unchanged to `_async_runner`.

    Returns
    -------
    ControlMessage
        `ctrl_msg` untouched when its payload dataframe is empty; otherwise the message returned by
        `_concatenate_extractions` for the per-content-type dataframes and masks collected here.
    """
    # `None` marks a content type that has no extractor yet. A placeholder callable must not be used
    # here: any function object is truthy, so the "unsupported" check below would never trigger.
    cudf_content_extractor = {
        ContentTypeEnum.TEXT: _get_cudf_text_content,
        ContentTypeEnum.STRUCTURED: _get_cudf_table_content,
        ContentTypeEnum.IMAGE: _get_cudf_image_content,
        ContentTypeEnum.AUDIO: None,  # Not supported yet.
        ContentTypeEnum.VIDEO: None,  # Not supported yet.
    }
    pandas_content_extractor = {
        ContentTypeEnum.TEXT: _get_pandas_text_content,
        ContentTypeEnum.STRUCTURED: _get_pandas_table_content,
        ContentTypeEnum.IMAGE: _get_pandas_image_content,
        ContentTypeEnum.AUDIO: None,  # Not supported yet.
        ContentTypeEnum.VIDEO: None,  # Not supported yet.
    }

    logger.debug("Generating text embeddings for supported content types: TEXT, STRUCTURED, IMAGE.")

    embedding_dataframes = []
    content_masks = []

    with ctrl_msg.payload().mutable_dataframe() as mdf:
        if mdf.empty:
            return ctrl_msg

        for content_type, content_getter in pandas_content_extractor.items():
            cudf_content_getter = cudf_content_extractor.get(content_type)
            if content_getter is None or cudf_content_getter is None:
                logger.debug(f"Skipping unsupported content type: {content_type}")
                continue

            # First narrow by document type; skip early so the struct field is only read when needed.
            content_mask = mdf["document_type"] == content_type.value
            if not content_mask.any():
                continue

            # Then drop rows whose content is empty; null comparisons are treated as "no content".
            content_mask = (content_mask & (cudf_content_getter(mdf["metadata"]) != "")).fillna(False)
            if not content_mask.any():
                continue

            df_content = mdf.loc[content_mask].to_pandas().reset_index(drop=True)
            filtered_content = df_content["metadata"].apply(content_getter)
            # calculate embeddings
            filtered_content_batches = _generate_batches(filtered_content.tolist(), batch_size)
            content_embeddings = _async_runner(
                filtered_content_batches,
                api_key,
                embedding_nim_endpoint,
                embedding_model,
                encoding_format,
                input_type,
                truncate,
                event_loop,
                filter_errors,
            )
            # update embeddings in metadata
            df_content[["metadata", "document_type", "_contains_embeddings"]] = df_content.apply(
                _add_embeddings, **content_embeddings, axis=1
            )[["metadata", "document_type", "_contains_embeddings"]]
            df_content["_content"] = filtered_content

            embedding_dataframes.append(df_content)
            content_masks.append(content_mask)

    # NOTE(review): called after the mutable dataframe context is released, matching the previous
    # flow where the caller invoked `_concatenate_extractions` once this function had returned.
    message = _concatenate_extractions(ctrl_msg, embedding_dataframes, content_masks)

    return message
453
504
454
505
455
506
def _concatenate_extractions (ctrl_msg : ControlMessage , dataframes : List [pd .DataFrame ], masks : List [cudf .Series ]):
@@ -493,8 +544,8 @@ def _concatenate_extractions(ctrl_msg: ControlMessage, dataframes: List[pd.DataF
493
544
@register_module (MODULE_NAME , MODULE_NAMESPACE )
494
545
def _embed_extractions (builder : mrc .Builder ):
495
546
"""
496
- A pipeline module that receives incoming messages in ControlMessage format and calculates embeddings for
497
- supported document types.
547
+ A pipeline module that receives incoming messages in ControlMessage format
548
+ and calculates text embeddings for all supported content types.
498
549
499
550
Parameters
500
551
----------
@@ -519,56 +570,20 @@ def embed_extractions_fn(message: ControlMessage):
519
570
try :
520
571
task_props = message .remove_task ("embed" )
521
572
model_dump = task_props .model_dump ()
522
- embed_text = model_dump .get ("text" )
523
- embed_tables = model_dump .get ("tables" )
524
573
filter_errors = model_dump .get ("filter_errors" , False )
525
574
526
- logger .debug (f"Generating embeddings: text={ embed_text } , tables={ embed_tables } " )
527
- embedding_dataframes = []
528
- content_masks = []
529
-
530
- if embed_text :
531
- df_text , content_mask = _generate_embeddings (
532
- message ,
533
- ContentTypeEnum .TEXT ,
534
- event_loop ,
535
- validated_config .batch_size ,
536
- validated_config .api_key ,
537
- validated_config .embedding_nim_endpoint ,
538
- validated_config .embedding_model ,
539
- validated_config .encoding_format ,
540
- validated_config .input_type ,
541
- validated_config .truncate ,
542
- filter_errors ,
543
- )
544
- if df_text is not None :
545
- embedding_dataframes .append (df_text )
546
- content_masks .append (content_mask )
547
-
548
- if embed_tables :
549
- df_tables , table_mask = _generate_embeddings (
550
- message ,
551
- ContentTypeEnum .STRUCTURED ,
552
- event_loop ,
553
- validated_config .batch_size ,
554
- validated_config .api_key ,
555
- validated_config .embedding_nim_endpoint ,
556
- validated_config .embedding_model ,
557
- validated_config .encoding_format ,
558
- validated_config .input_type ,
559
- validated_config .truncate ,
560
- filter_errors ,
561
- )
562
- if df_tables is not None :
563
- embedding_dataframes .append (df_tables )
564
- content_masks .append (table_mask )
565
-
566
- if len (content_masks ) == 0 :
567
- return message
568
-
569
- message = _concatenate_extractions (message , embedding_dataframes , content_masks )
570
-
571
- return message
575
+ return _generate_embeddings (
576
+ message ,
577
+ event_loop ,
578
+ validated_config .batch_size ,
579
+ validated_config .api_key ,
580
+ validated_config .embedding_nim_endpoint ,
581
+ validated_config .embedding_model ,
582
+ validated_config .encoding_format ,
583
+ validated_config .input_type ,
584
+ validated_config .truncate ,
585
+ filter_errors ,
586
+ )
572
587
573
588
except Exception as e :
574
589
traceback .print_exc ()
0 commit comments